基础架构

This commit is contained in:
2025-07-13 16:36:20 +08:00
parent e3d64e7485
commit 807004f78d
128 changed files with 17232 additions and 11396 deletions

View File

@@ -0,0 +1,284 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
logger *zap.Logger
prefix string
// 统计信息
hits int64
misses int64
}
// NewRedisCache 创建Redis缓存实例
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
return &RedisCache{
client: client,
logger: logger,
prefix: prefix,
}
}
// Name 返回服务名称
func (r *RedisCache) Name() string {
return "redis-cache"
}
// Initialize 初始化服务
func (r *RedisCache) Initialize(ctx context.Context) error {
// 测试连接
_, err := r.client.Ping(ctx).Result()
if err != nil {
r.logger.Error("Failed to connect to Redis", zap.Error(err))
return fmt.Errorf("redis connection failed: %w", err)
}
r.logger.Info("Redis cache service initialized")
return nil
}
// HealthCheck 健康检查
func (r *RedisCache) HealthCheck(ctx context.Context) error {
_, err := r.client.Ping(ctx).Result()
return err
}
// Shutdown 关闭服务
func (r *RedisCache) Shutdown(ctx context.Context) error {
return r.client.Close()
}
// Get 获取缓存值
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
fullKey := r.getFullKey(key)
val, err := r.client.Get(ctx, fullKey).Result()
if err != nil {
if err == redis.Nil {
r.misses++
return fmt.Errorf("cache miss: key %s not found", key)
}
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
return err
}
r.hits++
return json.Unmarshal([]byte(val), dest)
}
// Set 设置缓存值
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
fullKey := r.getFullKey(key)
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
case string:
expiration, _ = time.ParseDuration(v)
default:
expiration = 24 * time.Hour // 默认24小时
}
} else {
expiration = 24 * time.Hour // 默认24小时
}
err = r.client.Set(ctx, fullKey, data, expiration).Err()
if err != nil {
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
return err
}
return nil
}
// Delete 删除缓存
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
err := r.client.Del(ctx, fullKeys...).Err()
if err != nil {
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
return err
}
return nil
}
// Exists 检查键是否存在
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
fullKey := r.getFullKey(key)
count, err := r.client.Exists(ctx, fullKey).Result()
if err != nil {
return false, err
}
return count > 0, nil
}
// GetMultiple 批量获取
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
if len(keys) == 0 {
return make(map[string]interface{}), nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
values, err := r.client.MGet(ctx, fullKeys...).Result()
if err != nil {
return nil, err
}
result := make(map[string]interface{})
for i, val := range values {
if val != nil {
var data interface{}
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
result[keys[i]] = data
}
}
}
return result, nil
}
// SetMultiple 批量设置
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
if len(data) == 0 {
return nil
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
default:
expiration = 24 * time.Hour
}
} else {
expiration = 24 * time.Hour
}
pipe := r.client.Pipeline()
for key, value := range data {
fullKey := r.getFullKey(key)
jsonData, err := json.Marshal(value)
if err != nil {
continue
}
pipe.Set(ctx, fullKey, jsonData, expiration)
}
_, err := pipe.Exec(ctx)
return err
}
// DeletePattern 按模式删除
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return err
}
if len(keys) > 0 {
return r.client.Del(ctx, keys...).Err()
}
return nil
}
// Keys 获取匹配的键
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return nil, err
}
// 移除前缀
result := make([]string, len(keys))
prefixLen := len(r.prefix) + 1 // +1 for ":"
for i, key := range keys {
if len(key) > prefixLen {
result[i] = key[prefixLen:]
} else {
result[i] = key
}
}
return result, nil
}
// Stats 获取缓存统计
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
dbSize, _ := r.client.DBSize(ctx).Result()
return interfaces.CacheStats{
Hits: r.hits,
Misses: r.misses,
Keys: dbSize,
Memory: 0, // 暂时设为0后续可解析Redis info
Connections: 0, // 暂时设为0后续可解析Redis info
}, nil
}
// getFullKey 获取完整键名
func (r *RedisCache) getFullKey(key string) string {
if r.prefix == "" {
return key
}
return fmt.Sprintf("%s:%s", r.prefix, key)
}
// Flush 清空所有缓存
func (r *RedisCache) Flush(ctx context.Context) error {
if r.prefix == "" {
return r.client.FlushDB(ctx).Err()
}
// 只删除带前缀的键
return r.DeletePattern(ctx, "*")
}
// GetClient 获取原始Redis客户端
func (r *RedisCache) GetClient() *redis.Client {
return r.client
}

View File

@@ -0,0 +1,199 @@
package database
import (
"context"
"fmt"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// Config 数据库配置
type Config struct {
Host string
Port string
User string
Password string
Name string
SSLMode string
Timezone string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
}
// DB 数据库包装器
type DB struct {
*gorm.DB
config Config
}
// NewConnection 创建新的数据库连接
func NewConnection(config Config) (*DB, error) {
// 构建DSN
dsn := buildDSN(config)
// 配置GORM
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 使用单数表名
},
DisableForeignKeyConstraintWhenMigrating: true,
NowFunc: func() time.Time {
return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间
},
}
// 连接数据库
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
// 获取底层sql.DB
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 配置连接池
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
return &DB{
DB: db,
config: config,
}, nil
}
// buildDSN 构建数据库连接字符串
func buildDSN(config Config) string {
return fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'",
config.Host,
config.User,
config.Password,
config.Name,
config.Port,
config.SSLMode,
config.Timezone,
config.Timezone,
)
}
// Close 关闭数据库连接
func (db *DB) Close() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Ping 检查数据库连接
func (db *DB) Ping() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}
// GetStats 获取连接池统计信息
func (db *DB) GetStats() (map[string]interface{}, error) {
sqlDB, err := db.DB.DB()
if err != nil {
return nil, err
}
stats := sqlDB.Stats()
return map[string]interface{}{
"max_open_connections": stats.MaxOpenConnections,
"open_connections": stats.OpenConnections,
"in_use": stats.InUse,
"idle": stats.Idle,
"wait_count": stats.WaitCount,
"wait_duration": stats.WaitDuration,
"max_idle_closed": stats.MaxIdleClosed,
"max_idle_time_closed": stats.MaxIdleTimeClosed,
"max_lifetime_closed": stats.MaxLifetimeClosed,
}, nil
}
// BeginTx 开始事务
func (db *DB) BeginTx() *gorm.DB {
return db.DB.Begin()
}
// Migrate 执行数据库迁移
func (db *DB) Migrate(models ...interface{}) error {
return db.DB.AutoMigrate(models...)
}
// IsHealthy 检查数据库健康状态
func (db *DB) IsHealthy() bool {
return db.Ping() == nil
}
// WithContext 返回带上下文的数据库实例
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
if c, ok := ctx.(context.Context); ok {
return db.DB.WithContext(c)
}
return db.DB
}
// 事务包装器
type TxWrapper struct {
tx *gorm.DB
}
// NewTxWrapper 创建事务包装器
func (db *DB) NewTxWrapper() *TxWrapper {
return &TxWrapper{
tx: db.BeginTx(),
}
}
// Commit 提交事务
func (tx *TxWrapper) Commit() error {
return tx.tx.Commit().Error
}
// Rollback 回滚事务
func (tx *TxWrapper) Rollback() error {
return tx.tx.Rollback().Error
}
// GetDB 获取事务数据库实例
func (tx *TxWrapper) GetDB() *gorm.DB {
return tx.tx
}
// WithTx 在事务中执行函数
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
tx := db.BeginTx()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}

View File

@@ -0,0 +1,225 @@
package repositories
import (
"context"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/admin/entities"
"tyapi-server/internal/domains/admin/repositories"
"tyapi-server/internal/domains/admin/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormAdminLoginLogRepository 管理员登录日志GORM仓储实现
type GormAdminLoginLogRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.AdminLoginLogRepository = (*GormAdminLoginLogRepository)(nil)
// NewGormAdminLoginLogRepository 创建管理员登录日志GORM仓储
func NewGormAdminLoginLogRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminLoginLogRepository {
return &GormAdminLoginLogRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建登录日志
func (r *GormAdminLoginLogRepository) Create(ctx context.Context, log entities.AdminLoginLog) (entities.AdminLoginLog, error) {
r.logger.Info("创建管理员登录日志", zap.String("admin_id", log.AdminID))
err := r.db.WithContext(ctx).Create(&log).Error
return log, err
}
// GetByID 根据ID获取登录日志
func (r *GormAdminLoginLogRepository) GetByID(ctx context.Context, id string) (entities.AdminLoginLog, error) {
var log entities.AdminLoginLog
err := r.db.WithContext(ctx).Where("id = ?", id).First(&log).Error
return log, err
}
// Update 更新登录日志
func (r *GormAdminLoginLogRepository) Update(ctx context.Context, log entities.AdminLoginLog) error {
r.logger.Info("更新管理员登录日志", zap.String("id", log.ID))
return r.db.WithContext(ctx).Save(&log).Error
}
// Delete 删除登录日志
func (r *GormAdminLoginLogRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除管理员登录日志", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id = ?", id).Error
}
// SoftDelete 软删除登录日志
func (r *GormAdminLoginLogRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除管理员登录日志", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id = ?", id).Error
}
// Restore 恢复登录日志
func (r *GormAdminLoginLogRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复管理员登录日志", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminLoginLog{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计登录日志数量
func (r *GormAdminLoginLogRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("admin_id LIKE ? OR ip_address LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查登录日志是否存在
func (r *GormAdminLoginLogRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建登录日志
func (r *GormAdminLoginLogRepository) CreateBatch(ctx context.Context, logs []entities.AdminLoginLog) error {
r.logger.Info("批量创建管理员登录日志", zap.Int("count", len(logs)))
return r.db.WithContext(ctx).Create(&logs).Error
}
// GetByIDs 根据ID列表获取登录日志
func (r *GormAdminLoginLogRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminLoginLog, error) {
var logs []entities.AdminLoginLog
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&logs).Error
return logs, err
}
// UpdateBatch 批量更新登录日志
func (r *GormAdminLoginLogRepository) UpdateBatch(ctx context.Context, logs []entities.AdminLoginLog) error {
r.logger.Info("批量更新管理员登录日志", zap.Int("count", len(logs)))
return r.db.WithContext(ctx).Save(&logs).Error
}
// DeleteBatch 批量删除登录日志
func (r *GormAdminLoginLogRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除管理员登录日志", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id IN ?", ids).Error
}
// List 获取登录日志列表
func (r *GormAdminLoginLogRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminLoginLog, error) {
var logs []entities.AdminLoginLog
query := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("admin_id LIKE ? OR ip_address LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return logs, query.Find(&logs).Error
}
// WithTx 使用事务
func (r *GormAdminLoginLogRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminLoginLog] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormAdminLoginLogRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// ListLogs 获取登录日志列表(带分页和筛选)
func (r *GormAdminLoginLogRepository) ListLogs(ctx context.Context, query *queries.ListAdminLoginLogQuery) ([]*entities.AdminLoginLog, int64, error) {
var logs []entities.AdminLoginLog
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
// 应用筛选条件
if query.AdminID != "" {
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&logs).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
logPtrs := make([]*entities.AdminLoginLog, len(logs))
for i := range logs {
logPtrs[i] = &logs[i]
}
return logPtrs, total, nil
}
// GetTodayLoginCount 获取今日登录次数
func (r *GormAdminLoginLogRepository) GetTodayLoginCount(ctx context.Context) (int64, error) {
var count int64
today := time.Now().Truncate(24 * time.Hour)
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("created_at >= ?", today).Count(&count).Error
return count, err
}
// GetLoginCountByAdmin 获取指定管理员在指定天数内的登录次数
func (r *GormAdminLoginLogRepository) GetLoginCountByAdmin(ctx context.Context, adminID string, days int) (int64, error) {
var count int64
startDate := time.Now().AddDate(0, 0, -days)
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("admin_id = ? AND created_at >= ?", adminID, startDate).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,236 @@
package repositories
import (
"context"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/admin/entities"
"tyapi-server/internal/domains/admin/repositories"
"tyapi-server/internal/domains/admin/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormAdminOperationLogRepository 管理员操作日志GORM仓储实现
type GormAdminOperationLogRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.AdminOperationLogRepository = (*GormAdminOperationLogRepository)(nil)
// NewGormAdminOperationLogRepository 创建管理员操作日志GORM仓储
func NewGormAdminOperationLogRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminOperationLogRepository {
return &GormAdminOperationLogRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建操作日志
func (r *GormAdminOperationLogRepository) Create(ctx context.Context, log entities.AdminOperationLog) (entities.AdminOperationLog, error) {
r.logger.Info("创建管理员操作日志", zap.String("admin_id", log.AdminID), zap.String("action", log.Action))
err := r.db.WithContext(ctx).Create(&log).Error
return log, err
}
// GetByID 根据ID获取操作日志
func (r *GormAdminOperationLogRepository) GetByID(ctx context.Context, id string) (entities.AdminOperationLog, error) {
var log entities.AdminOperationLog
err := r.db.WithContext(ctx).Where("id = ?", id).First(&log).Error
return log, err
}
// Update 更新操作日志
func (r *GormAdminOperationLogRepository) Update(ctx context.Context, log entities.AdminOperationLog) error {
r.logger.Info("更新管理员操作日志", zap.String("id", log.ID))
return r.db.WithContext(ctx).Save(&log).Error
}
// Delete 删除操作日志
func (r *GormAdminOperationLogRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除管理员操作日志", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id = ?", id).Error
}
// SoftDelete 软删除操作日志
func (r *GormAdminOperationLogRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除管理员操作日志", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id = ?", id).Error
}
// Restore 恢复操作日志
func (r *GormAdminOperationLogRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复管理员操作日志", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminOperationLog{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计操作日志数量
func (r *GormAdminOperationLogRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("admin_id LIKE ? OR action LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查操作日志是否存在
func (r *GormAdminOperationLogRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建操作日志
func (r *GormAdminOperationLogRepository) CreateBatch(ctx context.Context, logs []entities.AdminOperationLog) error {
r.logger.Info("批量创建管理员操作日志", zap.Int("count", len(logs)))
return r.db.WithContext(ctx).Create(&logs).Error
}
// GetByIDs 根据ID列表获取操作日志
func (r *GormAdminOperationLogRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminOperationLog, error) {
var logs []entities.AdminOperationLog
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&logs).Error
return logs, err
}
// UpdateBatch 批量更新操作日志
func (r *GormAdminOperationLogRepository) UpdateBatch(ctx context.Context, logs []entities.AdminOperationLog) error {
r.logger.Info("批量更新管理员操作日志", zap.Int("count", len(logs)))
return r.db.WithContext(ctx).Save(&logs).Error
}
// DeleteBatch 批量删除操作日志
func (r *GormAdminOperationLogRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除管理员操作日志", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id IN ?", ids).Error
}
// List 获取操作日志列表
func (r *GormAdminOperationLogRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminOperationLog, error) {
var logs []entities.AdminOperationLog
query := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("admin_id LIKE ? OR action LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return logs, query.Find(&logs).Error
}
// WithTx 使用事务
func (r *GormAdminOperationLogRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminOperationLog] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormAdminOperationLogRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// ListLogs 获取操作日志列表(带分页和筛选)
func (r *GormAdminOperationLogRepository) ListLogs(ctx context.Context, query *queries.ListAdminOperationLogQuery) ([]*entities.AdminOperationLog, int64, error) {
var logs []entities.AdminOperationLog
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
// 应用筛选条件
if query.AdminID != "" {
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
}
if query.Module != "" {
dbQuery = dbQuery.Where("module = ?", query.Module)
}
if query.Action != "" {
dbQuery = dbQuery.Where("action = ?", query.Action)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&logs).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
logPtrs := make([]*entities.AdminOperationLog, len(logs))
for i := range logs {
logPtrs[i] = &logs[i]
}
return logPtrs, total, nil
}
// GetTotalOperations 获取总操作数
func (r *GormAdminOperationLogRepository) GetTotalOperations(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Count(&count).Error
return count, err
}
// GetOperationsByAdmin 获取指定管理员在指定天数内的操作数
func (r *GormAdminOperationLogRepository) GetOperationsByAdmin(ctx context.Context, adminID string, days int) (int64, error) {
var count int64
startDate := time.Now().AddDate(0, 0, -days)
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Where("admin_id = ? AND created_at >= ?", adminID, startDate).Count(&count).Error
return count, err
}
// BatchCreate 批量创建操作日志
func (r *GormAdminOperationLogRepository) BatchCreate(ctx context.Context, logs []entities.AdminOperationLog) error {
r.logger.Info("批量创建管理员操作日志", zap.Int("count", len(logs)))
return r.db.WithContext(ctx).Create(&logs).Error
}

View File

@@ -0,0 +1,222 @@
package repositories
import (
"context"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/admin/entities"
"tyapi-server/internal/domains/admin/repositories"
"tyapi-server/internal/shared/interfaces"
)
// GormAdminPermissionRepository 管理员权限GORM仓储实现
type GormAdminPermissionRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.AdminPermissionRepository = (*GormAdminPermissionRepository)(nil)
// NewGormAdminPermissionRepository 创建管理员权限GORM仓储
func NewGormAdminPermissionRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminPermissionRepository {
return &GormAdminPermissionRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建权限
func (r *GormAdminPermissionRepository) Create(ctx context.Context, permission entities.AdminPermission) (entities.AdminPermission, error) {
r.logger.Info("创建管理员权限", zap.String("code", permission.Code))
err := r.db.WithContext(ctx).Create(&permission).Error
return permission, err
}
// GetByID 根据ID获取权限
func (r *GormAdminPermissionRepository) GetByID(ctx context.Context, id string) (entities.AdminPermission, error) {
var permission entities.AdminPermission
err := r.db.WithContext(ctx).Where("id = ?", id).First(&permission).Error
return permission, err
}
// Update 更新权限
func (r *GormAdminPermissionRepository) Update(ctx context.Context, permission entities.AdminPermission) error {
r.logger.Info("更新管理员权限", zap.String("id", permission.ID))
return r.db.WithContext(ctx).Save(&permission).Error
}
// Delete 删除权限
func (r *GormAdminPermissionRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除管理员权限", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id = ?", id).Error
}
// SoftDelete 软删除权限
func (r *GormAdminPermissionRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除管理员权限", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id = ?", id).Error
}
// Restore 恢复权限
func (r *GormAdminPermissionRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复管理员权限", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminPermission{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计权限数量
func (r *GormAdminPermissionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.AdminPermission{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("code LIKE ? OR name LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查权限是否存在
func (r *GormAdminPermissionRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.AdminPermission{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建权限
func (r *GormAdminPermissionRepository) CreateBatch(ctx context.Context, permissions []entities.AdminPermission) error {
r.logger.Info("批量创建管理员权限", zap.Int("count", len(permissions)))
return r.db.WithContext(ctx).Create(&permissions).Error
}
// GetByIDs 根据ID列表获取权限
func (r *GormAdminPermissionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&permissions).Error
return permissions, err
}
// UpdateBatch 批量更新权限
func (r *GormAdminPermissionRepository) UpdateBatch(ctx context.Context, permissions []entities.AdminPermission) error {
r.logger.Info("批量更新管理员权限", zap.Int("count", len(permissions)))
return r.db.WithContext(ctx).Save(&permissions).Error
}
// DeleteBatch 批量删除权限
func (r *GormAdminPermissionRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除管理员权限", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id IN ?", ids).Error
}
// List 获取权限列表
func (r *GormAdminPermissionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
query := r.db.WithContext(ctx).Model(&entities.AdminPermission{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("code LIKE ? OR name LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return permissions, query.Find(&permissions).Error
}
// WithTx 使用事务
func (r *GormAdminPermissionRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminPermission] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormAdminPermissionRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// FindByCode 根据权限代码查找权限
func (r *GormAdminPermissionRepository) FindByCode(ctx context.Context, code string) (*entities.AdminPermission, error) {
var permission entities.AdminPermission
err := r.db.WithContext(ctx).Where("code = ?", code).First(&permission).Error
if err != nil {
return nil, err
}
return &permission, nil
}
// FindByModule 根据模块查找权限
func (r *GormAdminPermissionRepository) FindByModule(ctx context.Context, module string) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
err := r.db.WithContext(ctx).Where("module = ?", module).Find(&permissions).Error
return permissions, err
}
// ListActive 获取所有激活的权限
func (r *GormAdminPermissionRepository) ListActive(ctx context.Context) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
err := r.db.WithContext(ctx).Where("is_active = ?", true).Find(&permissions).Error
return permissions, err
}
// GetPermissionsByRole 根据角色获取权限
func (r *GormAdminPermissionRepository) GetPermissionsByRole(ctx context.Context, role entities.AdminRole) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
query := r.db.WithContext(ctx).
Joins("JOIN admin_role_permissions ON admin_permissions.id = admin_role_permissions.permission_id").
Where("admin_role_permissions.role = ? AND admin_permissions.is_active = ?", role, true)
return permissions, query.Find(&permissions).Error
}
// AssignPermissionsToRole 为角色分配权限
func (r *GormAdminPermissionRepository) AssignPermissionsToRole(ctx context.Context, role entities.AdminRole, permissionIDs []string) error {
// 先删除现有权限
if err := r.db.WithContext(ctx).Where("role = ?", role).Delete(&entities.AdminRolePermission{}).Error; err != nil {
return err
}
// 批量插入新权限
var rolePermissions []entities.AdminRolePermission
for _, permissionID := range permissionIDs {
rolePermissions = append(rolePermissions, entities.AdminRolePermission{
Role: role,
PermissionID: permissionID,
})
}
return r.db.WithContext(ctx).Create(&rolePermissions).Error
}
// RemovePermissionsFromRole 从角色移除权限
func (r *GormAdminPermissionRepository) RemovePermissionsFromRole(ctx context.Context, role entities.AdminRole, permissionIDs []string) error {
return r.db.WithContext(ctx).Where("role = ? AND permission_id IN ?", role, permissionIDs).Delete(&entities.AdminRolePermission{}).Error
}

View File

@@ -0,0 +1,319 @@
package repositories
import (
"context"
"encoding/json"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/admin/entities"
"tyapi-server/internal/domains/admin/repositories"
"tyapi-server/internal/domains/admin/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormAdminRepository 管理员GORM仓储实现
type GormAdminRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.AdminRepository = (*GormAdminRepository)(nil)
// NewGormAdminRepository 创建管理员GORM仓储
func NewGormAdminRepository(db *gorm.DB, logger *zap.Logger) *GormAdminRepository {
return &GormAdminRepository{
db: db,
logger: logger,
}
}
// Create 创建管理员
func (r *GormAdminRepository) Create(ctx context.Context, admin entities.Admin) (entities.Admin, error) {
r.logger.Info("创建管理员", zap.String("username", admin.Username))
err := r.db.WithContext(ctx).Create(&admin).Error
return admin, err
}
// GetByID 根据ID获取管理员
func (r *GormAdminRepository) GetByID(ctx context.Context, id string) (entities.Admin, error) {
var admin entities.Admin
err := r.db.WithContext(ctx).Where("id = ?", id).First(&admin).Error
return admin, err
}
// Update 更新管理员
func (r *GormAdminRepository) Update(ctx context.Context, admin entities.Admin) error {
r.logger.Info("更新管理员", zap.String("id", admin.ID))
return r.db.WithContext(ctx).Save(&admin).Error
}
// Delete 删除管理员
func (r *GormAdminRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除管理员", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id = ?", id).Error
}
// SoftDelete 软删除管理员
func (r *GormAdminRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除管理员", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id = ?", id).Error
}
// Restore 恢复管理员
func (r *GormAdminRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复管理员", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.Admin{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计管理员数量
func (r *GormAdminRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Admin{})
// 应用过滤条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR email LIKE ? OR real_name LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查管理员是否存在
func (r *GormAdminRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Admin{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建管理员
func (r *GormAdminRepository) CreateBatch(ctx context.Context, admins []entities.Admin) error {
r.logger.Info("批量创建管理员", zap.Int("count", len(admins)))
return r.db.WithContext(ctx).Create(&admins).Error
}
// GetByIDs 根据ID列表获取管理员
func (r *GormAdminRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Admin, error) {
var admins []entities.Admin
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&admins).Error
return admins, err
}
// UpdateBatch 批量更新管理员
func (r *GormAdminRepository) UpdateBatch(ctx context.Context, admins []entities.Admin) error {
r.logger.Info("批量更新管理员", zap.Int("count", len(admins)))
return r.db.WithContext(ctx).Save(&admins).Error
}
// DeleteBatch 批量删除管理员
func (r *GormAdminRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除管理员", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id IN ?", ids).Error
}
// List 获取管理员列表
func (r *GormAdminRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Admin, error) {
var admins []entities.Admin
query := r.db.WithContext(ctx).Model(&entities.Admin{})
// 应用过滤条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR email LIKE ? OR real_name LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return admins, query.Find(&admins).Error
}
// WithTx 使用事务
func (r *GormAdminRepository) WithTx(tx interface{}) interfaces.Repository[entities.Admin] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormAdminRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// FindByUsername 根据用户名查找管理员
func (r *GormAdminRepository) FindByUsername(ctx context.Context, username string) (*entities.Admin, error) {
var admin entities.Admin
err := r.db.WithContext(ctx).Where("username = ?", username).First(&admin).Error
if err != nil {
return nil, err
}
return &admin, nil
}
// FindByEmail 根据邮箱查找管理员
func (r *GormAdminRepository) FindByEmail(ctx context.Context, email string) (*entities.Admin, error) {
var admin entities.Admin
err := r.db.WithContext(ctx).Where("email = ?", email).First(&admin).Error
if err != nil {
return nil, err
}
return &admin, nil
}
// ListAdmins 获取管理员列表(带分页和筛选)
func (r *GormAdminRepository) ListAdmins(ctx context.Context, query *queries.ListAdminsQuery) ([]*entities.Admin, int64, error) {
var admins []entities.Admin
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Admin{})
// 应用筛选条件
if query.Username != "" {
dbQuery = dbQuery.Where("username LIKE ?", "%"+query.Username+"%")
}
if query.Email != "" {
dbQuery = dbQuery.Where("email LIKE ?", "%"+query.Email+"%")
}
if query.Role != "" {
dbQuery = dbQuery.Where("role = ?", query.Role)
}
if query.IsActive != nil {
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&admins).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
adminPtrs := make([]*entities.Admin, len(admins))
for i := range admins {
adminPtrs[i] = &admins[i]
}
return adminPtrs, total, nil
}
// GetStats 获取管理员统计信息
func (r *GormAdminRepository) GetStats(ctx context.Context, query *queries.GetAdminInfoQuery) (*repositories.AdminStats, error) {
var stats repositories.AdminStats
// 总管理员数
if err := r.db.WithContext(ctx).Model(&entities.Admin{}).Count(&stats.TotalAdmins).Error; err != nil {
return nil, err
}
// 激活管理员数
if err := r.db.WithContext(ctx).Model(&entities.Admin{}).Where("is_active = ?", true).Count(&stats.ActiveAdmins).Error; err != nil {
return nil, err
}
// 今日登录数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("created_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
// 总操作数
if err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Count(&stats.TotalOperations).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetPermissionsByRole 根据角色获取权限
func (r *GormAdminRepository) GetPermissionsByRole(ctx context.Context, role entities.AdminRole) ([]entities.AdminPermission, error) {
var permissions []entities.AdminPermission
query := r.db.WithContext(ctx).
Joins("JOIN admin_role_permissions ON admin_permissions.id = admin_role_permissions.permission_id").
Where("admin_role_permissions.role = ? AND admin_permissions.is_active = ?", role, true)
return permissions, query.Find(&permissions).Error
}
// UpdatePermissions 更新管理员权限
func (r *GormAdminRepository) UpdatePermissions(ctx context.Context, adminID string, permissions []string) error {
permissionsJSON, err := json.Marshal(permissions)
if err != nil {
return fmt.Errorf("序列化权限失败: %w", err)
}
return r.db.WithContext(ctx).
Model(&entities.Admin{}).
Where("id = ?", adminID).
Update("permissions", string(permissionsJSON)).Error
}
// UpdateLoginStats 更新登录统计
func (r *GormAdminRepository) UpdateLoginStats(ctx context.Context, adminID string) error {
return r.db.WithContext(ctx).
Model(&entities.Admin{}).
Where("id = ?", adminID).
Updates(map[string]interface{}{
"last_login_at": time.Now(),
"login_count": gorm.Expr("login_count + 1"),
}).Error
}
// UpdateReviewStats 更新审核统计
func (r *GormAdminRepository) UpdateReviewStats(ctx context.Context, adminID string, approved bool) error {
updates := map[string]interface{}{
"review_count": gorm.Expr("review_count + 1"),
}
if approved {
updates["approved_count"] = gorm.Expr("approved_count + 1")
} else {
updates["rejected_count"] = gorm.Expr("rejected_count + 1")
}
return r.db.WithContext(ctx).
Model(&entities.Admin{}).
Where("id = ?", adminID).
Updates(updates).Error
}

View File

@@ -0,0 +1,353 @@
package repositories
import (
"context"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/domains/certification/repositories"
"tyapi-server/internal/domains/certification/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormCertificationRepository GORM认证仓储实现
type GormCertificationRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.CertificationRepository = (*GormCertificationRepository)(nil)
// NewGormCertificationRepository 创建GORM认证仓储
func NewGormCertificationRepository(db *gorm.DB, logger *zap.Logger) repositories.CertificationRepository {
return &GormCertificationRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建认证申请
func (r *GormCertificationRepository) Create(ctx context.Context, cert entities.Certification) (entities.Certification, error) {
r.logger.Info("创建认证申请", zap.String("user_id", cert.UserID))
err := r.db.WithContext(ctx).Create(&cert).Error
return cert, err
}
// GetByID 根据ID获取认证申请
func (r *GormCertificationRepository) GetByID(ctx context.Context, id string) (entities.Certification, error) {
var cert entities.Certification
err := r.db.WithContext(ctx).Where("id = ?", id).First(&cert).Error
return cert, err
}
// Update 更新认证申请
func (r *GormCertificationRepository) Update(ctx context.Context, cert entities.Certification) error {
r.logger.Info("更新认证申请", zap.String("id", cert.ID))
return r.db.WithContext(ctx).Save(&cert).Error
}
// Delete 删除认证申请
func (r *GormCertificationRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除认证申请", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id = ?", id).Error
}
// SoftDelete 软删除认证申请
func (r *GormCertificationRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除认证申请", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id = ?", id).Error
}
// Restore 恢复认证申请
func (r *GormCertificationRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复认证申请", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.Certification{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计认证申请数量
func (r *GormCertificationRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Certification{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查认证申请是否存在
func (r *GormCertificationRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建认证申请
func (r *GormCertificationRepository) CreateBatch(ctx context.Context, certs []entities.Certification) error {
r.logger.Info("批量创建认证申请", zap.Int("count", len(certs)))
return r.db.WithContext(ctx).Create(&certs).Error
}
// GetByIDs 根据ID列表获取认证申请
func (r *GormCertificationRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Certification, error) {
var certs []entities.Certification
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&certs).Error
return certs, err
}
// UpdateBatch 批量更新认证申请
func (r *GormCertificationRepository) UpdateBatch(ctx context.Context, certs []entities.Certification) error {
r.logger.Info("批量更新认证申请", zap.Int("count", len(certs)))
return r.db.WithContext(ctx).Save(&certs).Error
}
// DeleteBatch 批量删除认证申请
func (r *GormCertificationRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除认证申请", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id IN ?", ids).Error
}
// List 获取认证申请列表
func (r *GormCertificationRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Certification, error) {
var certs []entities.Certification
query := r.db.WithContext(ctx).Model(&entities.Certification{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return certs, query.Find(&certs).Error
}
// WithTx 使用事务
func (r *GormCertificationRepository) WithTx(tx interface{}) interfaces.Repository[entities.Certification] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormCertificationRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// ListCertifications 获取认证申请列表(带分页和筛选)
func (r *GormCertificationRepository) ListCertifications(ctx context.Context, query *queries.ListCertificationsQuery) ([]*entities.Certification, int64, error) {
var certs []entities.Certification
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Certification{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.AdminID != "" {
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
if query.EnterpriseName != "" {
dbQuery = dbQuery.Joins("JOIN enterprises ON certifications.enterprise_id = enterprises.id").
Where("enterprises.enterprise_name LIKE ?", "%"+query.EnterpriseName+"%")
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&certs).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
certPtrs := make([]*entities.Certification, len(certs))
for i := range certs {
certPtrs[i] = &certs[i]
}
return certPtrs, total, nil
}
// GetByUserID 根据用户ID获取认证申请
func (r *GormCertificationRepository) GetByUserID(ctx context.Context, userID string) (*entities.Certification, error) {
var cert entities.Certification
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&cert).Error
if err != nil {
return nil, err
}
return &cert, nil
}
// GetByStatus 根据状态获取认证申请列表
func (r *GormCertificationRepository) GetByStatus(ctx context.Context, status string) ([]*entities.Certification, error) {
var certs []entities.Certification
err := r.db.WithContext(ctx).Where("status = ?", status).Find(&certs).Error
if err != nil {
return nil, err
}
certPtrs := make([]*entities.Certification, len(certs))
for i := range certs {
certPtrs[i] = &certs[i]
}
return certPtrs, nil
}
// UpdateStatus 更新认证状态
func (r *GormCertificationRepository) UpdateStatus(ctx context.Context, certificationID string, status string, adminID *string, notes string) error {
updates := map[string]interface{}{
"status": status,
}
if adminID != nil {
updates["admin_id"] = *adminID
}
if notes != "" {
updates["approval_notes"] = notes
}
// 根据状态设置相应的时间戳
switch status {
case "INFO_SUBMITTED":
updates["info_submitted_at"] = time.Now()
case "FACE_VERIFIED":
updates["face_verified_at"] = time.Now()
case "CONTRACT_APPLIED":
updates["contract_applied_at"] = time.Now()
case "CONTRACT_APPROVED":
updates["contract_approved_at"] = time.Now()
case "CONTRACT_SIGNED":
updates["contract_signed_at"] = time.Now()
case "COMPLETED":
updates["completed_at"] = time.Now()
}
return r.db.WithContext(ctx).
Model(&entities.Certification{}).
Where("id = ?", certificationID).
Updates(updates).Error
}
// GetPendingCertifications 获取待审核的认证申请
func (r *GormCertificationRepository) GetPendingCertifications(ctx context.Context) ([]*entities.Certification, error) {
return r.GetByStatus(ctx, "CONTRACT_PENDING")
}
// GetStats 获取认证统计信息
func (r *GormCertificationRepository) GetStats(ctx context.Context) (*repositories.CertificationStats, error) {
var stats repositories.CertificationStats
// 总认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Count(&stats.TotalCertifications).Error; err != nil {
return nil, err
}
// 待审核认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "CONTRACT_PENDING").Count(&stats.PendingCertifications).Error; err != nil {
return nil, err
}
// 已完成认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "COMPLETED").Count(&stats.CompletedCertifications).Error; err != nil {
return nil, err
}
// 被拒绝认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "REJECTED").Count(&stats.RejectedCertifications).Error; err != nil {
return nil, err
}
// 今日提交数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at >= ?", today).Count(&stats.TodaySubmissions).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetStatsByDateRange 根据日期范围获取认证统计信息
func (r *GormCertificationRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.CertificationStats, error) {
var stats repositories.CertificationStats
// 总认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at BETWEEN ? AND ?", startDate, endDate).Count(&stats.TotalCertifications).Error; err != nil {
return nil, err
}
// 待审核认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "CONTRACT_PENDING", startDate, endDate).Count(&stats.PendingCertifications).Error; err != nil {
return nil, err
}
// 已完成认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "COMPLETED", startDate, endDate).Count(&stats.CompletedCertifications).Error; err != nil {
return nil, err
}
// 被拒绝认证申请数
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "REJECTED", startDate, endDate).Count(&stats.RejectedCertifications).Error; err != nil {
return nil, err
}
// 今日提交数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at >= ?", today).Count(&stats.TodaySubmissions).Error; err != nil {
return nil, err
}
return &stats, nil
}

View File

@@ -0,0 +1,422 @@
package repositories
import (
"context"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/domains/certification/repositories"
"tyapi-server/internal/domains/certification/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormContractRecordRepository GORM合同记录仓储实现
type GormContractRecordRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.ContractRecordRepository = (*GormContractRecordRepository)(nil)
// NewGormContractRecordRepository 创建GORM合同记录仓储
func NewGormContractRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractRecordRepository {
return &GormContractRecordRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建合同记录
func (r *GormContractRecordRepository) Create(ctx context.Context, record entities.ContractRecord) (entities.ContractRecord, error) {
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
r.logger.Error("创建合同记录失败",
zap.String("certification_id", record.CertificationID),
zap.String("contract_type", record.ContractType),
zap.Error(err),
)
return entities.ContractRecord{}, fmt.Errorf("创建合同记录失败: %w", err)
}
r.logger.Info("合同记录创建成功",
zap.String("id", record.ID),
zap.String("contract_type", record.ContractType),
)
return record, nil
}
// GetByID 根据ID获取合同记录
func (r *GormContractRecordRepository) GetByID(ctx context.Context, id string) (entities.ContractRecord, error) {
var record entities.ContractRecord
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.ContractRecord{}, fmt.Errorf("合同记录不存在")
}
r.logger.Error("获取合同记录失败",
zap.String("id", id),
zap.Error(err),
)
return entities.ContractRecord{}, fmt.Errorf("获取合同记录失败: %w", err)
}
return record, nil
}
// Update 更新合同记录
func (r *GormContractRecordRepository) Update(ctx context.Context, record entities.ContractRecord) error {
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
r.logger.Error("更新合同记录失败",
zap.String("id", record.ID),
zap.Error(err),
)
return fmt.Errorf("更新合同记录失败: %w", err)
}
return nil
}
// Delete 删除合同记录
func (r *GormContractRecordRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.ContractRecord{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除合同记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("删除合同记录失败: %w", err)
}
return nil
}
// SoftDelete 软删除合同记录
func (r *GormContractRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复合同记录
func (r *GormContractRecordRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.ContractRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复合同记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("恢复合同记录失败: %w", err)
}
r.logger.Info("合同记录恢复成功", zap.String("id", id))
return nil
}
// Count 统计合同记录数量
func (r *GormContractRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("contract_type LIKE ? OR contract_name LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查合同记录是否存在
func (r *GormContractRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建合同记录
func (r *GormContractRecordRepository) CreateBatch(ctx context.Context, records []entities.ContractRecord) error {
r.logger.Info("批量创建合同记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Create(&records).Error
}
// GetByIDs 根据ID列表获取合同记录
func (r *GormContractRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ContractRecord, error) {
var records []entities.ContractRecord
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
// UpdateBatch 批量更新合同记录
func (r *GormContractRecordRepository) UpdateBatch(ctx context.Context, records []entities.ContractRecord) error {
r.logger.Info("批量更新合同记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Save(&records).Error
}
// DeleteBatch 批量删除合同记录
func (r *GormContractRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除合同记录", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.ContractRecord{}, "id IN ?", ids).Error
}
// List 获取合同记录列表
func (r *GormContractRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ContractRecord, error) {
var records []entities.ContractRecord
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("contract_type LIKE ? OR contract_name LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return records, query.Find(&records).Error
}
// WithTx 使用事务
func (r *GormContractRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.ContractRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormContractRecordRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByCertificationID 根据认证申请ID获取合同记录列表
func (r *GormContractRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.ContractRecord, error) {
var records []entities.ContractRecord
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("根据认证申请ID获取合同记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取合同记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.ContractRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}
// GetLatestByCertificationID 根据认证申请ID获取最新的合同记录
func (r *GormContractRecordRepository) GetLatestByCertificationID(ctx context.Context, certificationID string) (*entities.ContractRecord, error) {
var record entities.ContractRecord
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").First(&record).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("合同记录不存在")
}
r.logger.Error("根据认证申请ID获取最新合同记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取合同记录失败: %w", err)
}
return &record, nil
}
// ListRecords 获取合同记录列表(带分页和筛选)
func (r *GormContractRecordRepository) ListRecords(ctx context.Context, query *queries.ListContractRecordsQuery) ([]*entities.ContractRecord, int64, error) {
var records []entities.ContractRecord
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
// 应用筛选条件
if query.CertificationID != "" {
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
}
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&records).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
recordPtrs := make([]*entities.ContractRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, total, nil
}
// UpdateContractStatus 更新合同状态
func (r *GormContractRecordRepository) UpdateContractStatus(ctx context.Context, recordID string, status string, adminID *string, notes string) error {
updates := map[string]interface{}{
"status": status,
}
if adminID != nil {
updates["admin_id"] = *adminID
}
if notes != "" {
updates["admin_notes"] = notes
}
if err := r.db.WithContext(ctx).
Model(&entities.ContractRecord{}).
Where("id = ?", recordID).
Updates(updates).Error; err != nil {
r.logger.Error("更新合同状态失败",
zap.String("record_id", recordID),
zap.String("status", status),
zap.Error(err),
)
return fmt.Errorf("更新合同状态失败: %w", err)
}
r.logger.Info("合同状态更新成功",
zap.String("record_id", recordID),
zap.String("status", status),
)
return nil
}
// GetByUserID 根据用户ID获取合同记录列表
func (r *GormContractRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.ContractRecord, int, error) {
var records []entities.ContractRecord
var total int64
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("user_id = ?", userID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
r.logger.Error("获取用户合同记录总数失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取合同记录总数失败: %w", err)
}
// 分页查询
offset := (page - 1) * pageSize
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("获取用户合同记录列表失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取合同记录列表失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.ContractRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, int(total), nil
}
// GetByStatus 根据状态获取合同记录列表
func (r *GormContractRecordRepository) GetByStatus(ctx context.Context, status string, page, pageSize int) ([]*entities.ContractRecord, int, error) {
var records []entities.ContractRecord
var total int64
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
r.logger.Error("根据状态获取合同记录总数失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取合同记录总数失败: %w", err)
}
// 分页查询
offset := (page - 1) * pageSize
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("根据状态获取合同记录列表失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取合同记录列表失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.ContractRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, int(total), nil
}
// GetPendingContracts 获取待审核的合同记录
func (r *GormContractRecordRepository) GetPendingContracts(ctx context.Context, page, pageSize int) ([]*entities.ContractRecord, int, error) {
return r.GetByStatus(ctx, "PENDING", page, pageSize)
}
// GetExpiredSigningContracts 获取签署链接已过期的合同记录
func (r *GormContractRecordRepository) GetExpiredSigningContracts(ctx context.Context, limit int) ([]*entities.ContractRecord, error) {
var records []entities.ContractRecord
if err := r.db.WithContext(ctx).
Where("expires_at < NOW() AND status = ?", "APPROVED").
Limit(limit).
Order("expires_at ASC").
Find(&records).Error; err != nil {
r.logger.Error("获取过期签署合同记录失败", zap.Error(err))
return nil, fmt.Errorf("获取过期签署合同记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.ContractRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}
// GetExpiredContracts 获取已过期的合同记录(通用方法)
func (r *GormContractRecordRepository) GetExpiredContracts(ctx context.Context, limit int) ([]*entities.ContractRecord, error) {
return r.GetExpiredSigningContracts(ctx, limit)
}

View File

@@ -0,0 +1,394 @@
package repositories
import (
"context"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/domains/certification/repositories"
"tyapi-server/internal/domains/certification/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormFaceVerifyRecordRepository GORM人脸识别记录仓储实现
type GormFaceVerifyRecordRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.FaceVerifyRecordRepository = (*GormFaceVerifyRecordRepository)(nil)
// NewGormFaceVerifyRecordRepository 创建GORM人脸识别记录仓储
func NewGormFaceVerifyRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.FaceVerifyRecordRepository {
return &GormFaceVerifyRecordRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建人脸识别记录
func (r *GormFaceVerifyRecordRepository) Create(ctx context.Context, record entities.FaceVerifyRecord) (entities.FaceVerifyRecord, error) {
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
r.logger.Error("创建人脸识别记录失败",
zap.String("certification_id", record.CertificationID),
zap.String("certify_id", record.CertifyID),
zap.Error(err),
)
return entities.FaceVerifyRecord{}, fmt.Errorf("创建人脸识别记录失败: %w", err)
}
r.logger.Info("人脸识别记录创建成功",
zap.String("id", record.ID),
zap.String("certify_id", record.CertifyID),
)
return record, nil
}
// GetByID 根据ID获取人脸识别记录
func (r *GormFaceVerifyRecordRepository) GetByID(ctx context.Context, id string) (entities.FaceVerifyRecord, error) {
var record entities.FaceVerifyRecord
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.FaceVerifyRecord{}, fmt.Errorf("人脸识别记录不存在")
}
r.logger.Error("获取人脸识别记录失败",
zap.String("id", id),
zap.Error(err),
)
return entities.FaceVerifyRecord{}, fmt.Errorf("获取人脸识别记录失败: %w", err)
}
return record, nil
}
// Update 更新人脸识别记录
func (r *GormFaceVerifyRecordRepository) Update(ctx context.Context, record entities.FaceVerifyRecord) error {
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
r.logger.Error("更新人脸识别记录失败",
zap.String("id", record.ID),
zap.Error(err),
)
return fmt.Errorf("更新人脸识别记录失败: %w", err)
}
return nil
}
// Delete 删除人脸识别记录
func (r *GormFaceVerifyRecordRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.FaceVerifyRecord{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除人脸识别记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("删除人脸识别记录失败: %w", err)
}
return nil
}
// SoftDelete 软删除人脸识别记录
func (r *GormFaceVerifyRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复人脸识别记录
func (r *GormFaceVerifyRecordRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.FaceVerifyRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复人脸识别记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("恢复人脸识别记录失败: %w", err)
}
r.logger.Info("人脸识别记录恢复成功", zap.String("id", id))
return nil
}
// Count 统计人脸识别记录数量
func (r *GormFaceVerifyRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("certify_id LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查人脸识别记录是否存在
func (r *GormFaceVerifyRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建人脸识别记录
func (r *GormFaceVerifyRecordRepository) CreateBatch(ctx context.Context, records []entities.FaceVerifyRecord) error {
r.logger.Info("批量创建人脸识别记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Create(&records).Error
}
// GetByIDs 根据ID列表获取人脸识别记录
func (r *GormFaceVerifyRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.FaceVerifyRecord, error) {
var records []entities.FaceVerifyRecord
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
// UpdateBatch 批量更新人脸识别记录
func (r *GormFaceVerifyRecordRepository) UpdateBatch(ctx context.Context, records []entities.FaceVerifyRecord) error {
r.logger.Info("批量更新人脸识别记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Save(&records).Error
}
// DeleteBatch 批量删除人脸识别记录
func (r *GormFaceVerifyRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除人脸识别记录", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.FaceVerifyRecord{}, "id IN ?", ids).Error
}
// List 获取人脸识别记录列表
func (r *GormFaceVerifyRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.FaceVerifyRecord, error) {
var records []entities.FaceVerifyRecord
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("certify_id LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return records, query.Find(&records).Error
}
// WithTx 使用事务
func (r *GormFaceVerifyRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.FaceVerifyRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormFaceVerifyRecordRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByCertificationID 根据认证申请ID获取人脸识别记录列表
func (r *GormFaceVerifyRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.FaceVerifyRecord, error) {
var records []entities.FaceVerifyRecord
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("根据认证申请ID获取人脸识别记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}
// GetLatestByCertificationID 根据认证申请ID获取最新的人脸识别记录
func (r *GormFaceVerifyRecordRepository) GetLatestByCertificationID(ctx context.Context, certificationID string) (*entities.FaceVerifyRecord, error) {
var record entities.FaceVerifyRecord
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").First(&record).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("人脸识别记录不存在")
}
r.logger.Error("根据认证申请ID获取最新人脸识别记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
}
return &record, nil
}
// ListRecords 获取人脸识别记录列表(带分页和筛选)
func (r *GormFaceVerifyRecordRepository) ListRecords(ctx context.Context, query *queries.ListFaceVerifyRecordsQuery) ([]*entities.FaceVerifyRecord, int64, error) {
var records []entities.FaceVerifyRecord
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
// 应用筛选条件
if query.CertificationID != "" {
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
}
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&records).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, total, nil
}
// GetSuccessRate 获取成功率
func (r *GormFaceVerifyRecordRepository) GetSuccessRate(ctx context.Context, days int) (float64, error) {
var totalCount int64
var successCount int64
// 计算指定天数前的日期
startDate := fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d DAY)", days)
// 获取总数
if err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).
Where("created_at >= " + startDate).Count(&totalCount).Error; err != nil {
return 0, err
}
// 获取成功数
if err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).
Where("created_at >= "+startDate+" AND status = ?", "SUCCESS").Count(&successCount).Error; err != nil {
return 0, err
}
if totalCount == 0 {
return 0, nil
}
return float64(successCount) / float64(totalCount) * 100, nil
}
// GetByCertifyID 根据认证ID获取人脸识别记录
func (r *GormFaceVerifyRecordRepository) GetByCertifyID(ctx context.Context, certifyID string) (*entities.FaceVerifyRecord, error) {
var record entities.FaceVerifyRecord
if err := r.db.WithContext(ctx).First(&record, "certify_id = ?", certifyID).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("人脸识别记录不存在")
}
r.logger.Error("根据认证ID获取人脸识别记录失败",
zap.String("certify_id", certifyID),
zap.Error(err),
)
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
}
return &record, nil
}
// GetByUserID 根据用户ID获取人脸识别记录列表
func (r *GormFaceVerifyRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.FaceVerifyRecord, int, error) {
var records []entities.FaceVerifyRecord
var total int64
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).Where("user_id = ?", userID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
r.logger.Error("获取用户人脸识别记录总数失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取人脸识别记录总数失败: %w", err)
}
// 分页查询
offset := (page - 1) * pageSize
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("获取用户人脸识别记录列表失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取人脸识别记录列表失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, int(total), nil
}
// GetExpiredRecords 获取已过期的人脸识别记录
func (r *GormFaceVerifyRecordRepository) GetExpiredRecords(ctx context.Context, limit int) ([]*entities.FaceVerifyRecord, error) {
var records []entities.FaceVerifyRecord
if err := r.db.WithContext(ctx).
Where("expires_at < NOW() AND status = ?", "PROCESSING").
Limit(limit).
Order("expires_at ASC").
Find(&records).Error; err != nil {
r.logger.Error("获取过期人脸识别记录失败", zap.Error(err))
return nil, fmt.Errorf("获取过期人脸识别记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}

View File

@@ -0,0 +1,374 @@
package repositories
import (
"context"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/domains/certification/repositories"
"tyapi-server/internal/domains/certification/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormLicenseUploadRecordRepository GORM营业执照上传记录仓储实现
type GormLicenseUploadRecordRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.LicenseUploadRecordRepository = (*GormLicenseUploadRecordRepository)(nil)
// NewGormLicenseUploadRecordRepository 创建GORM营业执照上传记录仓储
func NewGormLicenseUploadRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.LicenseUploadRecordRepository {
return &GormLicenseUploadRecordRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建上传记录
func (r *GormLicenseUploadRecordRepository) Create(ctx context.Context, record entities.LicenseUploadRecord) (entities.LicenseUploadRecord, error) {
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
r.logger.Error("创建上传记录失败",
zap.String("user_id", record.UserID),
zap.String("file_name", record.OriginalFileName),
zap.Error(err),
)
return entities.LicenseUploadRecord{}, fmt.Errorf("创建上传记录失败: %w", err)
}
r.logger.Info("上传记录创建成功",
zap.String("id", record.ID),
zap.String("file_name", record.OriginalFileName),
)
return record, nil
}
// GetByID 根据ID获取上传记录
func (r *GormLicenseUploadRecordRepository) GetByID(ctx context.Context, id string) (entities.LicenseUploadRecord, error) {
var record entities.LicenseUploadRecord
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.LicenseUploadRecord{}, fmt.Errorf("上传记录不存在")
}
r.logger.Error("获取上传记录失败",
zap.String("id", id),
zap.Error(err),
)
return entities.LicenseUploadRecord{}, fmt.Errorf("获取上传记录失败: %w", err)
}
return record, nil
}
// Update 更新上传记录
func (r *GormLicenseUploadRecordRepository) Update(ctx context.Context, record entities.LicenseUploadRecord) error {
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
r.logger.Error("更新上传记录失败",
zap.String("id", record.ID),
zap.Error(err),
)
return fmt.Errorf("更新上传记录失败: %w", err)
}
return nil
}
// Delete 删除上传记录
func (r *GormLicenseUploadRecordRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.LicenseUploadRecord{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除上传记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("删除上传记录失败: %w", err)
}
return nil
}
// SoftDelete 软删除上传记录
func (r *GormLicenseUploadRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复上传记录
func (r *GormLicenseUploadRecordRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.LicenseUploadRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复上传记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("恢复上传记录失败: %w", err)
}
r.logger.Info("上传记录恢复成功", zap.String("id", id))
return nil
}
// Count 统计上传记录数量
func (r *GormLicenseUploadRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("original_file_name LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查上传记录是否存在
func (r *GormLicenseUploadRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建上传记录
func (r *GormLicenseUploadRecordRepository) CreateBatch(ctx context.Context, records []entities.LicenseUploadRecord) error {
r.logger.Info("批量创建上传记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Create(&records).Error
}
// GetByIDs 根据ID列表获取上传记录
func (r *GormLicenseUploadRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.LicenseUploadRecord, error) {
var records []entities.LicenseUploadRecord
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
// UpdateBatch 批量更新上传记录
func (r *GormLicenseUploadRecordRepository) UpdateBatch(ctx context.Context, records []entities.LicenseUploadRecord) error {
r.logger.Info("批量更新上传记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Save(&records).Error
}
// DeleteBatch 批量删除上传记录
func (r *GormLicenseUploadRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除上传记录", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.LicenseUploadRecord{}, "id IN ?", ids).Error
}
// List 获取上传记录列表
func (r *GormLicenseUploadRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.LicenseUploadRecord, error) {
var records []entities.LicenseUploadRecord
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("original_file_name LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return records, query.Find(&records).Error
}
// WithTx 使用事务
func (r *GormLicenseUploadRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.LicenseUploadRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormLicenseUploadRecordRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByCertificationID 根据认证ID获取上传记录
func (r *GormLicenseUploadRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) (*entities.LicenseUploadRecord, error) {
var record entities.LicenseUploadRecord
if err := r.db.WithContext(ctx).First(&record, "certification_id = ?", certificationID).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("上传记录不存在")
}
r.logger.Error("根据认证ID获取上传记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取上传记录失败: %w", err)
}
return &record, nil
}
// ListRecords 获取上传记录列表(带分页和筛选)
func (r *GormLicenseUploadRecordRepository) ListRecords(ctx context.Context, query *queries.ListLicenseUploadRecordsQuery) ([]*entities.LicenseUploadRecord, int64, error) {
var records []entities.LicenseUploadRecord
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
// 应用筛选条件
if query.CertificationID != "" {
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
}
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&records).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, total, nil
}
// UpdateOCRResult 更新OCR结果
func (r *GormLicenseUploadRecordRepository) UpdateOCRResult(ctx context.Context, recordID string, ocrResult string, confidence float64) error {
updates := map[string]interface{}{
"ocr_result": ocrResult,
"ocr_confidence": confidence,
"ocr_processed": true,
"ocr_success": true,
}
if err := r.db.WithContext(ctx).
Model(&entities.LicenseUploadRecord{}).
Where("id = ?", recordID).
Updates(updates).Error; err != nil {
r.logger.Error("更新OCR结果失败",
zap.String("record_id", recordID),
zap.Error(err),
)
return fmt.Errorf("更新OCR结果失败: %w", err)
}
r.logger.Info("OCR结果更新成功",
zap.String("record_id", recordID),
zap.Float64("confidence", confidence),
)
return nil
}
// GetByUserID 根据用户ID获取上传记录列表
func (r *GormLicenseUploadRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.LicenseUploadRecord, int, error) {
var records []entities.LicenseUploadRecord
var total int64
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{}).Where("user_id = ?", userID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
r.logger.Error("获取用户上传记录总数失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取上传记录总数失败: %w", err)
}
// 分页查询
offset := (page - 1) * pageSize
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("获取用户上传记录列表失败", zap.Error(err))
return nil, 0, fmt.Errorf("获取上传记录列表失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, int(total), nil
}
// GetByQiNiuKey 根据七牛云Key获取上传记录
func (r *GormLicenseUploadRecordRepository) GetByQiNiuKey(ctx context.Context, key string) (*entities.LicenseUploadRecord, error) {
var record entities.LicenseUploadRecord
if err := r.db.WithContext(ctx).First(&record, "qiniu_key = ?", key).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("上传记录不存在")
}
r.logger.Error("根据七牛云Key获取上传记录失败",
zap.String("qiniu_key", key),
zap.Error(err),
)
return nil, fmt.Errorf("获取上传记录失败: %w", err)
}
return &record, nil
}
// GetPendingOCR 获取待OCR处理的上传记录
func (r *GormLicenseUploadRecordRepository) GetPendingOCR(ctx context.Context, limit int) ([]*entities.LicenseUploadRecord, error) {
var records []entities.LicenseUploadRecord
if err := r.db.WithContext(ctx).
Where("ocr_processed = ? OR (ocr_processed = ? AND ocr_success = ?)", false, true, false).
Limit(limit).
Order("created_at ASC").
Find(&records).Error; err != nil {
r.logger.Error("获取待OCR处理记录失败", zap.Error(err))
return nil, fmt.Errorf("获取待OCR处理记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}

View File

@@ -0,0 +1,344 @@
package repositories
import (
"context"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/domains/certification/repositories"
"tyapi-server/internal/domains/certification/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormNotificationRecordRepository GORM通知记录仓储实现
type GormNotificationRecordRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.NotificationRecordRepository = (*GormNotificationRecordRepository)(nil)
// NewGormNotificationRecordRepository 创建GORM通知记录仓储
func NewGormNotificationRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.NotificationRecordRepository {
return &GormNotificationRecordRepository{
db: db,
logger: logger,
}
}
// ================ 基础CRUD操作 ================
// Create 创建通知记录
func (r *GormNotificationRecordRepository) Create(ctx context.Context, record entities.NotificationRecord) (entities.NotificationRecord, error) {
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
r.logger.Error("创建通知记录失败",
zap.String("user_id", *record.UserID),
zap.String("type", record.NotificationType),
zap.Error(err),
)
return entities.NotificationRecord{}, fmt.Errorf("创建通知记录失败: %w", err)
}
r.logger.Info("通知记录创建成功",
zap.String("id", record.ID),
zap.String("type", record.NotificationType),
)
return record, nil
}
// GetByID 根据ID获取通知记录
func (r *GormNotificationRecordRepository) GetByID(ctx context.Context, id string) (entities.NotificationRecord, error) {
var record entities.NotificationRecord
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.NotificationRecord{}, fmt.Errorf("通知记录不存在")
}
r.logger.Error("获取通知记录失败",
zap.String("id", id),
zap.Error(err),
)
return entities.NotificationRecord{}, fmt.Errorf("获取通知记录失败: %w", err)
}
return record, nil
}
// Update 更新通知记录
func (r *GormNotificationRecordRepository) Update(ctx context.Context, record entities.NotificationRecord) error {
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
r.logger.Error("更新通知记录失败",
zap.String("id", record.ID),
zap.Error(err),
)
return fmt.Errorf("更新通知记录失败: %w", err)
}
return nil
}
// Delete 删除通知记录
func (r *GormNotificationRecordRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.NotificationRecord{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除通知记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("删除通知记录失败: %w", err)
}
return nil
}
// SoftDelete 软删除通知记录
func (r *GormNotificationRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复通知记录
func (r *GormNotificationRecordRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.NotificationRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复通知记录失败",
zap.String("id", id),
zap.Error(err),
)
return fmt.Errorf("恢复通知记录失败: %w", err)
}
r.logger.Info("通知记录恢复成功", zap.String("id", id))
return nil
}
// Count 统计通知记录数量
func (r *GormNotificationRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("title LIKE ? OR content LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查通知记录是否存在
func (r *GormNotificationRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.NotificationRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建通知记录
func (r *GormNotificationRecordRepository) CreateBatch(ctx context.Context, records []entities.NotificationRecord) error {
r.logger.Info("批量创建通知记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Create(&records).Error
}
// GetByIDs 根据ID列表获取通知记录
func (r *GormNotificationRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.NotificationRecord, error) {
var records []entities.NotificationRecord
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
// UpdateBatch 批量更新通知记录
func (r *GormNotificationRecordRepository) UpdateBatch(ctx context.Context, records []entities.NotificationRecord) error {
r.logger.Info("批量更新通知记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Save(&records).Error
}
// DeleteBatch 批量删除通知记录
func (r *GormNotificationRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除通知记录", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.NotificationRecord{}, "id IN ?", ids).Error
}
// List 获取通知记录列表
func (r *GormNotificationRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.NotificationRecord, error) {
var records []entities.NotificationRecord
query := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("title LIKE ? OR content LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return records, query.Find(&records).Error
}
// WithTx 使用事务
func (r *GormNotificationRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.NotificationRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormNotificationRecordRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByCertificationID 根据认证申请ID获取通知记录列表
func (r *GormNotificationRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.NotificationRecord, error) {
var records []entities.NotificationRecord
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("根据认证申请ID获取通知记录失败",
zap.String("certification_id", certificationID),
zap.Error(err),
)
return nil, fmt.Errorf("获取通知记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.NotificationRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}
// GetUnreadByUserID 根据用户ID获取未读通知记录列表
func (r *GormNotificationRecordRepository) GetUnreadByUserID(ctx context.Context, userID string) ([]*entities.NotificationRecord, error) {
var records []entities.NotificationRecord
if err := r.db.WithContext(ctx).Where("user_id = ? AND is_read = ?", userID, false).Order("created_at DESC").Find(&records).Error; err != nil {
r.logger.Error("根据用户ID获取未读通知记录失败",
zap.String("user_id", userID),
zap.Error(err),
)
return nil, fmt.Errorf("获取未读通知记录失败: %w", err)
}
// 转换为指针切片
recordPtrs := make([]*entities.NotificationRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, nil
}
// ListRecords 获取通知记录列表(带分页和筛选)
func (r *GormNotificationRecordRepository) ListRecords(ctx context.Context, query *queries.ListNotificationRecordsQuery) ([]*entities.NotificationRecord, int64, error) {
var records []entities.NotificationRecord
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
// 应用筛选条件
if query.CertificationID != "" {
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
}
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.Type != "" {
dbQuery = dbQuery.Where("type = ?", query.Type)
}
if query.IsRead != nil {
dbQuery = dbQuery.Where("is_read = ?", *query.IsRead)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&records).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
recordPtrs := make([]*entities.NotificationRecord, len(records))
for i := range records {
recordPtrs[i] = &records[i]
}
return recordPtrs, total, nil
}
// BatchCreate 批量创建通知记录
func (r *GormNotificationRecordRepository) BatchCreate(ctx context.Context, records []entities.NotificationRecord) error {
r.logger.Info("批量创建通知记录", zap.Int("count", len(records)))
return r.db.WithContext(ctx).Create(&records).Error
}
// MarkAsRead 标记通知记录为已读
func (r *GormNotificationRecordRepository) MarkAsRead(ctx context.Context, recordIDs []string) error {
if err := r.db.WithContext(ctx).
Model(&entities.NotificationRecord{}).
Where("id IN ?", recordIDs).
Update("is_read", true).Error; err != nil {
r.logger.Error("标记通知记录为已读失败",
zap.Strings("record_ids", recordIDs),
zap.Error(err),
)
return fmt.Errorf("标记通知记录为已读失败: %w", err)
}
r.logger.Info("通知记录标记为已读成功", zap.Strings("record_ids", recordIDs))
return nil
}
// MarkAllAsReadByUser 标记用户所有通知记录为已读
func (r *GormNotificationRecordRepository) MarkAllAsReadByUser(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.NotificationRecord{}).
Where("user_id = ? AND is_read = ?", userID, false).
Update("is_read", true).Error; err != nil {
r.logger.Error("标记用户所有通知记录为已读失败",
zap.String("user_id", userID),
zap.Error(err),
)
return fmt.Errorf("标记用户所有通知记录为已读失败: %w", err)
}
r.logger.Info("用户所有通知记录标记为已读成功", zap.String("user_id", userID))
return nil
}

View File

@@ -0,0 +1,663 @@
package repositories
import (
"context"
"time"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/domains/finance/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormWalletRepository 钱包GORM仓储实现
type GormWalletRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
// NewGormWalletRepository 创建钱包GORM仓储
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
return &GormWalletRepository{
db: db,
logger: logger,
}
}
// Create 创建钱包
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
r.logger.Info("创建钱包", zap.String("user_id", wallet.UserID))
err := r.db.WithContext(ctx).Create(&wallet).Error
return wallet, err
}
// GetByID 根据ID获取钱包
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("id = ?", id).First(&wallet).Error
return wallet, err
}
// Update 更新钱包
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
r.logger.Info("更新钱包", zap.String("id", wallet.ID))
return r.db.WithContext(ctx).Save(&wallet).Error
}
// Delete 删除钱包
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除钱包", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
}
// SoftDelete 软删除钱包
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除钱包", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
}
// Restore 恢复钱包
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复钱包", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.Wallet{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计钱包数量
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查钱包是否存在
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建钱包
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
r.logger.Info("批量创建钱包", zap.Int("count", len(wallets)))
return r.db.WithContext(ctx).Create(&wallets).Error
}
// GetByIDs 根据ID列表获取钱包
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
var wallets []entities.Wallet
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&wallets).Error
return wallets, err
}
// UpdateBatch 批量更新钱包
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
r.logger.Info("批量更新钱包", zap.Int("count", len(wallets)))
return r.db.WithContext(ctx).Save(&wallets).Error
}
// DeleteBatch 批量删除钱包
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除钱包", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
}
// List 获取钱包列表
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
var wallets []entities.Wallet
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return wallets, query.Find(&wallets).Error
}
// WithTx 使用事务
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// FindByUserID 根据用户ID查找钱包
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
return nil, err
}
return &wallet, nil
}
// ExistsByUserID 检查用户钱包是否存在
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
// GetTotalBalance 获取总余额
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
var total decimal.Decimal
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
return total, err
}
// GetActiveWalletCount 获取激活钱包数量
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
return count, err
}
// ================ 接口要求的方法 ================
// GetByUserID 根据用户ID获取钱包
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
return nil, err
}
return &wallet, nil
}
// GetByWalletAddress 根据钱包地址获取钱包
func (r *GormWalletRepository) GetByWalletAddress(ctx context.Context, walletAddress string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("wallet_address = ?", walletAddress).First(&wallet).Error
if err != nil {
return nil, err
}
return &wallet, nil
}
// GetByWalletType 根据钱包类型获取钱包
func (r *GormWalletRepository) GetByWalletType(ctx context.Context, userID string, walletType string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ? AND wallet_type = ?", userID, walletType).First(&wallet).Error
if err != nil {
return nil, err
}
return &wallet, nil
}
// ListWallets 获取钱包列表(带分页和筛选)
func (r *GormWalletRepository) ListWallets(ctx context.Context, query *queries.ListWalletsQuery) ([]*entities.Wallet, int64, error) {
var wallets []entities.Wallet
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Wallet{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.WalletType != "" {
dbQuery = dbQuery.Where("wallet_type = ?", query.WalletType)
}
if query.WalletAddress != "" {
dbQuery = dbQuery.Where("wallet_address LIKE ?", "%"+query.WalletAddress+"%")
}
if query.IsActive != nil {
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&wallets).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
walletPtrs := make([]*entities.Wallet, len(wallets))
for i := range wallets {
walletPtrs[i] = &wallets[i]
}
return walletPtrs, total, nil
}
// UpdateBalance 更新钱包余额
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
}
// AddBalance 增加钱包余额
func (r *GormWalletRepository) AddBalance(ctx context.Context, walletID string, amount string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// SubtractBalance 减少钱包余额
func (r *GormWalletRepository) SubtractBalance(ctx context.Context, walletID string, amount string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance - ?", amount)).Error
}
// ActivateWallet 激活钱包
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
}
// DeactivateWallet 停用钱包
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
}
// GetStats 获取财务统计信息
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 总钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 激活钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 总余额
var totalBalance decimal.Decimal
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
// GetUserWalletStats 获取用户钱包统计信息
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 用户钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 用户激活钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 用户总余额
var totalBalance decimal.Decimal
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 用户今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
// GormUserSecretsRepository 用户密钥GORM仓储实现
type GormUserSecretsRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ domain_finance_repo.UserSecretsRepository = (*GormUserSecretsRepository)(nil)
// NewGormUserSecretsRepository 创建用户密钥GORM仓储
func NewGormUserSecretsRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.UserSecretsRepository {
return &GormUserSecretsRepository{
db: db,
logger: logger,
}
}
// Create 创建用户密钥
func (r *GormUserSecretsRepository) Create(ctx context.Context, secrets entities.UserSecrets) (entities.UserSecrets, error) {
r.logger.Info("创建用户密钥", zap.String("user_id", secrets.UserID))
err := r.db.WithContext(ctx).Create(&secrets).Error
return secrets, err
}
// GetByID 根据ID获取用户密钥
func (r *GormUserSecretsRepository) GetByID(ctx context.Context, id string) (entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("id = ?", id).First(&secrets).Error
return secrets, err
}
// Update 更新用户密钥
func (r *GormUserSecretsRepository) Update(ctx context.Context, secrets entities.UserSecrets) error {
r.logger.Info("更新用户密钥", zap.String("id", secrets.ID))
return r.db.WithContext(ctx).Save(&secrets).Error
}
// Delete 删除用户密钥
func (r *GormUserSecretsRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
}
// SoftDelete 软删除用户密钥
func (r *GormUserSecretsRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
}
// Restore 恢复用户密钥
func (r *GormUserSecretsRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.UserSecrets{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计用户密钥数量
func (r *GormUserSecretsRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查用户密钥是否存在
func (r *GormUserSecretsRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建用户密钥
func (r *GormUserSecretsRepository) CreateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
r.logger.Info("批量创建用户密钥", zap.Int("count", len(secrets)))
return r.db.WithContext(ctx).Create(&secrets).Error
}
// GetByIDs 根据ID列表获取用户密钥
func (r *GormUserSecretsRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&secrets).Error
return secrets, err
}
// UpdateBatch 批量更新用户密钥
func (r *GormUserSecretsRepository) UpdateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
r.logger.Info("批量更新用户密钥", zap.Int("count", len(secrets)))
return r.db.WithContext(ctx).Save(&secrets).Error
}
// DeleteBatch 批量删除用户密钥
func (r *GormUserSecretsRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除用户密钥", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id IN ?", ids).Error
}
// List 获取用户密钥列表
func (r *GormUserSecretsRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return secrets, query.Find(&secrets).Error
}
// WithTx 使用事务
func (r *GormUserSecretsRepository) WithTx(tx interface{}) interfaces.Repository[entities.UserSecrets] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormUserSecretsRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// FindByUserID 根据用户ID查找密钥
func (r *GormUserSecretsRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
if err != nil {
return nil, err
}
return &secrets, nil
}
// FindByAccessID 根据访问ID查找密钥
func (r *GormUserSecretsRepository) FindByAccessID(ctx context.Context, accessID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("access_id = ?", accessID).First(&secrets).Error
if err != nil {
return nil, err
}
return &secrets, nil
}
// ExistsByUserID 检查用户密钥是否存在
func (r *GormUserSecretsRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
// ExistsByAccessID 检查访问ID是否存在
func (r *GormUserSecretsRepository) ExistsByAccessID(ctx context.Context, accessID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Count(&count).Error
return count > 0, err
}
// UpdateLastUsedAt 更新最后使用时间
func (r *GormUserSecretsRepository) UpdateLastUsedAt(ctx context.Context, accessID string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Update("last_used_at", time.Now()).Error
}
// DeactivateByUserID 停用用户密钥
func (r *GormUserSecretsRepository) DeactivateByUserID(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Update("is_active", false).Error
}
// RegenerateAccessKey 重新生成访问密钥
func (r *GormUserSecretsRepository) RegenerateAccessKey(ctx context.Context, userID string, accessID, accessKey string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"access_id": accessID,
"access_key": accessKey,
"updated_at": time.Now(),
}).Error
}
// GetExpiredSecrets 获取过期的密钥
func (r *GormUserSecretsRepository) GetExpiredSecrets(ctx context.Context) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
err := r.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Find(&secrets).Error
return secrets, err
}
// DeleteExpiredSecrets 删除过期的密钥
func (r *GormUserSecretsRepository) DeleteExpiredSecrets(ctx context.Context) error {
return r.db.WithContext(ctx).Where("expires_at < ?", time.Now()).Delete(&entities.UserSecrets{}).Error
}
// ================ 接口要求的方法 ================
// GetByUserID 根据用户ID获取用户密钥
func (r *GormUserSecretsRepository) GetByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
if err != nil {
return nil, err
}
return &secrets, nil
}
// GetBySecretType 根据密钥类型获取用户密钥
func (r *GormUserSecretsRepository) GetBySecretType(ctx context.Context, userID string, secretType string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).First(&secrets).Error
if err != nil {
return nil, err
}
return &secrets, nil
}
// ListUserSecrets 获取用户密钥列表(带分页和筛选)
func (r *GormUserSecretsRepository) ListUserSecrets(ctx context.Context, query *queries.ListUserSecretsQuery) ([]*entities.UserSecrets, int64, error) {
var secrets []entities.UserSecrets
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.SecretType != "" {
dbQuery = dbQuery.Where("secret_type = ?", query.SecretType)
}
if query.IsActive != nil {
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&secrets).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
secretPtrs := make([]*entities.UserSecrets, len(secrets))
for i := range secrets {
secretPtrs[i] = &secrets[i]
}
return secretPtrs, total, nil
}
// UpdateSecret 更新密钥
func (r *GormUserSecretsRepository) UpdateSecret(ctx context.Context, userID string, secretType string, secretValue string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
Where("user_id = ? AND secret_type = ?", userID, secretType).
Update("secret_value", secretValue).Error
}
// DeleteSecret 删除密钥
func (r *GormUserSecretsRepository) DeleteSecret(ctx context.Context, userID string, secretType string) error {
return r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).
Delete(&entities.UserSecrets{}).Error
}
// ValidateSecret 验证密钥
func (r *GormUserSecretsRepository) ValidateSecret(ctx context.Context, userID string, secretType string, secretValue string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
Where("user_id = ? AND secret_type = ? AND secret_value = ?", userID, secretType, secretValue).
Count(&count).Error
return count > 0, err
}

View File

@@ -0,0 +1,273 @@
package repositories
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/shared/interfaces"
)
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
type GormEnterpriseInfoRepository struct {
db *gorm.DB
logger *zap.Logger
}
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
return &GormEnterpriseInfoRepository{
db: db,
logger: logger,
}
}
// Create 创建企业信息
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
r.logger.Error("创建企业信息失败", zap.Error(err))
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
}
return enterpriseInfo, nil
}
// GetByID 根据ID获取企业信息
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
}
return enterpriseInfo, nil
}
// Update 更新企业信息
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
// 检查企业信息是否已认证完成,认证完成后不可修改
if enterpriseInfo.IsReadOnly() {
return fmt.Errorf("企业信息已认证完成,不可修改")
}
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
r.logger.Error("更新企业信息失败", zap.Error(err))
return fmt.Errorf("更新企业信息失败: %w", err)
}
return nil
}
// Delete 删除企业信息
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除企业信息失败", zap.Error(err))
return fmt.Errorf("删除企业信息失败: %w", err)
}
return nil
}
// SoftDelete 软删除企业信息
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
r.logger.Error("软删除企业信息失败", zap.Error(err))
return fmt.Errorf("软删除企业信息失败: %w", err)
}
return nil
}
// Restore 恢复软删除的企业信息
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复企业信息失败", zap.Error(err))
return fmt.Errorf("恢复企业信息失败: %w", err)
}
return nil
}
// GetByUserID 根据用户ID获取企业信息
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return nil, fmt.Errorf("获取企业信息失败: %w", err)
}
return &enterpriseInfo, nil
}
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return nil, fmt.Errorf("获取企业信息失败: %w", err)
}
return &enterpriseInfo, nil
}
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
if excludeUserID != "" {
query = query.Where("user_id != ?", excludeUserID)
}
if err := query.Count(&count).Error; err != nil {
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
}
return count > 0, nil
}
// UpdateVerificationStatus 更新验证状态
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
updates := map[string]interface{}{
"is_ocr_verified": isOCRVerified,
"is_face_verified": isFaceVerified,
"is_certified": isCertified,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("更新验证状态失败", zap.Error(err))
return fmt.Errorf("更新验证状态失败: %w", err)
}
return nil
}
// UpdateOCRData 更新OCR数据
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
updates := map[string]interface{}{
"ocr_raw_data": rawData,
"ocr_confidence": confidence,
"is_ocr_verified": true,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("更新OCR数据失败", zap.Error(err))
return fmt.Errorf("更新OCR数据失败: %w", err)
}
return nil
}
// CompleteCertification 完成认证
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
now := time.Now()
updates := map[string]interface{}{
"is_certified": true,
"certified_at": &now,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("完成认证失败", zap.Error(err))
return fmt.Errorf("完成认证失败: %w", err)
}
return nil
}
// Count 统计企业信息数量
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查企业信息是否存在
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建企业信息
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
}
// GetByIDs 根据ID列表获取企业信息
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
var enterpriseInfos []entities.EnterpriseInfo
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&enterpriseInfos).Error
return enterpriseInfos, err
}
// UpdateBatch 批量更新企业信息
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
}
// DeleteBatch 批量删除企业信息
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
}
// List 获取企业信息列表
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
var enterpriseInfos []entities.EnterpriseInfo
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&enterpriseInfos).Error
return enterpriseInfos, err
}
// WithTx 使用事务
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormEnterpriseInfoRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}

View File

@@ -0,0 +1,499 @@
//go:build !test
// +build !test
package repositories
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// SMSCodeRepository 短信验证码仓储
type GormSMSCodeRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewGormSMSCodeRepository 创建短信验证码仓储
func NewGormSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.SMSCodeRepository {
return &GormSMSCodeRepository{
db: db,
cache: cache,
logger: logger,
}
}
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
// ================ 基础CRUD操作 ================
// Create 创建短信验证码记录
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
// 缓存验证码
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
return smsCode, nil
}
// GetByID 根据ID获取短信验证码
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("获取短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
}
// Update 更新验证码记录
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
r.logger.Error("更新验证码记录失败", zap.Error(err))
return err
}
// 更新缓存
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
r.logger.Info("验证码记录更新成功", zap.String("code_id", smsCode.ID))
return nil
}
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除短信验证码失败", zap.Error(err))
return err
}
r.logger.Info("短信验证码删除成功", zap.String("id", id))
return nil
}
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复短信验证码失败", zap.Error(err))
return err
}
r.logger.Info("短信验证码恢复成功", zap.String("id", id))
return nil
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建短信验证码
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
r.logger.Info("批量创建短信验证码", zap.Int("count", len(smsCodes)))
return r.db.WithContext(ctx).Create(&smsCodes).Error
}
// GetByIDs 根据ID列表获取短信验证码
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
return smsCodes, err
}
// UpdateBatch 批量更新短信验证码
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
r.logger.Info("批量更新短信验证码", zap.Int("count", len(smsCodes)))
return r.db.WithContext(ctx).Save(&smsCodes).Error
}
// DeleteBatch 批量删除短信验证码
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除短信验证码", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
}
// List 获取短信验证码列表
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return smsCodes, query.Find(&smsCodes).Error
}
// WithTx 使用事务
func (r *GormSMSCodeRepository) WithTx(tx interface{}) interfaces.Repository[entities.SMSCode] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormSMSCodeRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByPhone 根据手机号获取短信验证码
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("根据手机号获取短信验证码失败", zap.Error(err))
return nil, err
}
return &smsCode, nil
}
// GetLatestByPhone 根据手机号获取最新的短信验证码
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
return r.GetByPhone(ctx, phone)
}
// GetValidByPhone 根据手机号获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
return r.GetValidCode(ctx, phone, "")
}
// GetValidByPhoneAndScene 根据手机号和场景获取有效的验证码
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
return r.GetValidCode(ctx, phone, scene)
}
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
var smsCodes []entities.SMSCode
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if query.Phone != "" {
dbQuery = dbQuery.Where("phone = ?", query.Phone)
}
if query.Purpose != "" {
dbQuery = dbQuery.Where("scene = ?", query.Purpose)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&smsCodes).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
smsCodePtrs := make([]*entities.SMSCode, len(smsCodes))
for i := range smsCodes {
smsCodePtrs[i] = &smsCodes[i]
}
return smsCodePtrs, total, nil
}
// CreateCode 创建验证码
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
smsCode := entities.SMSCode{
Phone: phone,
Code: code,
Scene: entities.SMSScene(purpose),
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟过期
Used: false,
}
return r.Create(ctx, smsCode)
}
// ValidateCode 验证验证码
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
phone, code, purpose, time.Now()).
First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return false, nil
}
r.logger.Error("验证验证码失败", zap.Error(err))
return false, err
}
// 标记为已使用
if err := r.MarkAsUsed(ctx, smsCode.ID); err != nil {
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
return false, err
}
return true, nil
}
// InvalidateCode 使验证码失效
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND used_at IS NULL", phone).
Update("used_at", time.Now()).Error; err != nil {
r.logger.Error("使验证码失效失败", zap.Error(err))
return err
}
// 清除缓存
cacheKey := r.buildCacheKey(phone, "")
r.cache.Delete(ctx, cacheKey)
r.logger.Info("验证码已失效", zap.String("phone", phone))
return nil
}
// CheckSendFrequency 检查发送频率
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
// 检查最近1分钟内是否已发送
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
Count(&count).Error; err != nil {
r.logger.Error("检查发送频率失败", zap.Error(err))
return false, err
}
return count == 0, nil
}
// GetTodaySendCount 获取今日发送次数
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
today := time.Now().Truncate(24 * time.Hour)
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&count).Error; err != nil {
r.logger.Error("获取今日发送次数失败", zap.Error(err))
return 0, err
}
return count, nil
}
// GetCodeStats 获取验证码统计信息
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
var stats repositories.SMSCodeStats
// 计算指定天数前的日期
startDate := time.Now().AddDate(0, 0, -days)
// 总发送数
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, startDate).
Count(&stats.TotalSent).Error; err != nil {
return nil, err
}
// 总验证数
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
Count(&stats.TotalValidated).Error; err != nil {
return nil, err
}
// 成功率
if stats.TotalSent > 0 {
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
}
// 今日发送数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&stats.TodaySent).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetValidCode 获取有效的验证码
func (r *GormSMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
// 先从缓存查找
cacheKey := r.buildCacheKey(phone, scene)
var smsCode entities.SMSCode
if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil {
return &smsCode, nil
}
// 从数据库查找最新的有效验证码
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
return &smsCode, nil
}
// MarkAsUsed 标记验证码为已使用
func (r *GormSMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
now := time.Now()
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("id = ?", id).
Update("used_at", now).Error; err != nil {
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
return err
}
r.logger.Info("验证码已标记为使用", zap.String("code_id", id))
return nil
}
// GetRecentCode 获取最近的验证码记录(不限制有效性)
func (r *GormSMSCodeRepository) GetRecentCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ?", phone, scene).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
return &smsCode, nil
}
// CleanupExpired 清理过期的验证码
func (r *GormSMSCodeRepository) CleanupExpired(ctx context.Context) error {
result := r.db.WithContext(ctx).
Where("expires_at < ?", time.Now()).
Delete(&entities.SMSCode{})
if result.Error != nil {
r.logger.Error("清理过期验证码失败", zap.Error(result.Error))
return result.Error
}
if result.RowsAffected > 0 {
r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected))
}
return nil
}
// CountRecentCodes 统计最近发送的验证码数量
func (r *GormSMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) {
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?",
phone, scene, time.Now().Add(-duration)).
Count(&count).Error; err != nil {
r.logger.Error("统计最近验证码数量失败", zap.Error(err))
return 0, err
}
return count, nil
}
// buildCacheKey 构建缓存键
func (r *GormSMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string {
return fmt.Sprintf("sms_code:%s:%s", phone, string(scene))
}

View File

@@ -0,0 +1,511 @@
//go:build !test
// +build !test
package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// 定义错误常量
var (
// ErrUserNotFound 用户不存在错误
ErrUserNotFound = errors.New("用户不存在")
)
// UserRepository 用户仓储实现
type GormUserRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewGormUserRepository 创建用户仓储
func NewGormUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
db: db,
cache: cache,
logger: logger,
}
}
// 确保 GormUserRepository 实现了 UserRepository 接口
var _ repositories.UserRepository = (*GormUserRepository)(nil)
// ================ 基础CRUD操作 ================
// Create 创建用户
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
r.logger.Error("创建用户失败", zap.Error(err))
return entities.User{}, err
}
// 清除相关缓存
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
return user, nil
}
// GetByID 根据ID获取用户
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:id:%s", id)
var userCache entities.UserCache
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
var user entities.User
user.FromCache(&userCache)
return user, nil
}
// 从数据库查询
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, ErrUserNotFound
}
r.logger.Error("根据ID查询用户失败", zap.Error(err))
return entities.User{}, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
return user, nil
}
// Update 更新用户
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
r.logger.Error("更新用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, user.ID)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
return nil
}
// Delete 删除用户
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// SoftDelete 软删除用户
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("软删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户软删除成功", zap.String("user_id", id))
return nil
}
// Restore 恢复软删除的用户
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.logger.Info("用户恢复成功", zap.String("user_id", id))
return nil
}
// Count 统计用户数量
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.User{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查用户是否存在
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建用户
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Create(&users).Error
}
// GetByIDs 根据ID列表获取用户
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
return users, err
}
// UpdateBatch 批量更新用户
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量更新用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Save(&users).Error
}
// DeleteBatch 批量删除用户
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除用户", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
// List 获取用户列表
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
query := r.db.WithContext(ctx).Model(&entities.User{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return users, query.Find(&users).Error
}
// WithTx 使用事务
func (r *GormUserRepository) WithTx(tx interface{}) interfaces.Repository[entities.User] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormUserRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
}
}
return r
}
// ================ 业务方法 ================
// GetByPhone 根据手机号获取用户
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:phone:%s", phone)
var userCache entities.UserCache
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
var user entities.User
user.FromCache(&userCache)
return &user, nil
}
// 从数据库查询
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
return &user, nil
}
// ListUsers 获取用户列表(带分页和筛选)
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
var users []entities.User
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if query.Phone != "" {
dbQuery = dbQuery.Where("phone LIKE ?", "%"+query.Phone+"%")
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&users).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
userPtrs := make([]*entities.User, len(users))
for i := range users {
userPtrs[i] = &users[i]
}
return userPtrs, total, nil
}
// ValidateUser 验证用户
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("手机号或密码错误")
}
r.logger.Error("验证用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// UpdateLastLogin 更新最后登录时间
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", userID).
Update("last_login_at", time.Now()).Error; err != nil {
r.logger.Error("更新最后登录时间失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("最后登录时间更新成功", zap.String("user_id", userID))
return nil
}
// UpdatePassword 更新密码
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", userID).
Update("password", newPassword).Error; err != nil {
r.logger.Error("更新密码失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("密码更新成功", zap.String("user_id", userID))
return nil
}
// CheckPassword 检查密码
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ? AND password = ?", userID, password).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
r.logger.Error("检查密码失败", zap.Error(err))
return false, err
}
return true, nil
}
// ActivateUser 激活用户
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", userID).
Update("status", "ACTIVE").Error; err != nil {
r.logger.Error("激活用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("用户激活成功", zap.String("user_id", userID))
return nil
}
// DeactivateUser 停用用户
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", userID).
Update("status", "INACTIVE").Error; err != nil {
r.logger.Error("停用用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("用户停用成功", zap.String("user_id", userID))
return nil
}
// GetStats 获取用户统计信息
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
// 总用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ?", "ACTIVE").Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetStatsByDateRange 根据日期范围获取用户统计信息
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
var stats repositories.UserStats
// 总用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at BETWEEN ? AND ?", startDate, endDate).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ? AND created_at BETWEEN ? AND ?", "ACTIVE", startDate, endDate).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// FindByPhone 根据手机号查找用户(兼容旧方法)
func (r *GormUserRepository) FindByPhone(ctx context.Context, phone string) (*entities.User, error) {
return r.GetByPhone(ctx, phone)
}
// ExistsByPhone 检查手机号是否存在
func (r *GormUserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
var count int64
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("phone = ?", phone).Count(&count).Error; err != nil {
r.logger.Error("检查手机号是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
// 私有辅助方法
// deleteCacheByID 根据ID删除缓存
func (r *GormUserRepository) deleteCacheByID(ctx context.Context, id string) {
cacheKey := fmt.Sprintf("user:id:%s", id)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户ID缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
}
// deleteCacheByPhone 根据手机号删除缓存
func (r *GormUserRepository) deleteCacheByPhone(ctx context.Context, phone string) {
cacheKey := fmt.Sprintf("user:phone:%s", phone)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户手机号缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
}

View File

@@ -0,0 +1,515 @@
package notification
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
)
// WeChatWorkService 企业微信通知服务
type WeChatWorkService struct {
webhookURL string
secret string
timeout time.Duration
logger *zap.Logger
}
// WechatWorkConfig 企业微信配置
type WechatWorkConfig struct {
WebhookURL string `yaml:"webhook_url"`
Timeout time.Duration `yaml:"timeout"`
}
// WechatWorkMessage 企业微信消息
type WechatWorkMessage struct {
MsgType string `json:"msgtype"`
Text *WechatWorkText `json:"text,omitempty"`
Markdown *WechatWorkMarkdown `json:"markdown,omitempty"`
}
// WechatWorkText 文本消息
type WechatWorkText struct {
Content string `json:"content"`
MentionedList []string `json:"mentioned_list,omitempty"`
MentionedMobileList []string `json:"mentioned_mobile_list,omitempty"`
}
// WechatWorkMarkdown Markdown消息
type WechatWorkMarkdown struct {
Content string `json:"content"`
}
// NewWeChatWorkService 创建企业微信通知服务
func NewWeChatWorkService(webhookURL, secret string, logger *zap.Logger) *WeChatWorkService {
return &WeChatWorkService{
webhookURL: webhookURL,
secret: secret,
timeout: 30 * time.Second,
logger: logger,
}
}
// SendTextMessage 发送文本消息
func (s *WeChatWorkService) SendTextMessage(ctx context.Context, content string, mentionedList []string, mentionedMobileList []string) error {
s.logger.Info("发送企业微信文本消息",
zap.String("content", content),
zap.Strings("mentioned_list", mentionedList),
)
message := map[string]interface{}{
"msgtype": "text",
"text": map[string]interface{}{
"content": content,
"mentioned_list": mentionedList,
"mentioned_mobile_list": mentionedMobileList,
},
}
return s.sendMessage(ctx, message)
}
// SendMarkdownMessage 发送Markdown消息
func (s *WeChatWorkService) SendMarkdownMessage(ctx context.Context, content string) error {
s.logger.Info("发送企业微信Markdown消息", zap.String("content", content))
message := map[string]interface{}{
"msgtype": "markdown",
"markdown": map[string]interface{}{
"content": content,
},
}
return s.sendMessage(ctx, message)
}
// SendCardMessage 发送卡片消息
func (s *WeChatWorkService) SendCardMessage(ctx context.Context, title, description, url string, btnText string) error {
s.logger.Info("发送企业微信卡片消息",
zap.String("title", title),
zap.String("description", description),
)
message := map[string]interface{}{
"msgtype": "template_card",
"template_card": map[string]interface{}{
"card_type": "text_notice",
"source": map[string]interface{}{
"icon_url": "https://example.com/icon.png",
"desc": "企业认证系统",
},
"main_title": map[string]interface{}{
"title": title,
},
"horizontal_content_list": []map[string]interface{}{
{
"keyname": "描述",
"value": description,
},
},
"jump_list": []map[string]interface{}{
{
"type": "1",
"title": btnText,
"url": url,
},
},
},
}
return s.sendMessage(ctx, message)
}
// SendCertificationNotification 发送认证相关通知
func (s *WeChatWorkService) SendCertificationNotification(ctx context.Context, notificationType string, data map[string]interface{}) error {
s.logger.Info("发送认证通知", zap.String("type", notificationType))
switch notificationType {
case "new_application":
return s.sendNewApplicationNotification(ctx, data)
case "ocr_success":
return s.sendOCRSuccessNotification(ctx, data)
case "ocr_failed":
return s.sendOCRFailedNotification(ctx, data)
case "face_verify_success":
return s.sendFaceVerifySuccessNotification(ctx, data)
case "face_verify_failed":
return s.sendFaceVerifyFailedNotification(ctx, data)
case "admin_approved":
return s.sendAdminApprovedNotification(ctx, data)
case "admin_rejected":
return s.sendAdminRejectedNotification(ctx, data)
case "contract_signed":
return s.sendContractSignedNotification(ctx, data)
case "certification_completed":
return s.sendCertificationCompletedNotification(ctx, data)
default:
return fmt.Errorf("不支持的通知类型: %s", notificationType)
}
}
// sendNewApplicationNotification 发送新申请通知
func (s *WeChatWorkService) sendNewApplicationNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
content := fmt.Sprintf(`## 🆕 新的企业认证申请
**企业名称**: %s
**申请人**: %s
**申请ID**: %s
**申请时间**: %s
请管理员及时审核处理。`,
companyName,
applicantName,
applicationID,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendOCRSuccessNotification 发送OCR识别成功通知
func (s *WeChatWorkService) sendOCRSuccessNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
confidence := data["confidence"].(float64)
applicationID := data["application_id"].(string)
content := fmt.Sprintf(`## ✅ OCR识别成功
**企业名称**: %s
**识别置信度**: %.2f%%
**申请ID**: %s
**识别时间**: %s
营业执照信息已自动提取,请用户确认信息。`,
companyName,
confidence*100,
applicationID,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendOCRFailedNotification 发送OCR识别失败通知
func (s *WeChatWorkService) sendOCRFailedNotification(ctx context.Context, data map[string]interface{}) error {
applicationID := data["application_id"].(string)
errorMsg := data["error_message"].(string)
content := fmt.Sprintf(`## ❌ OCR识别失败
**申请ID**: %s
**错误信息**: %s
**失败时间**: %s
请检查营业执照图片质量或联系技术支持。`,
applicationID,
errorMsg,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendFaceVerifySuccessNotification 发送人脸识别成功通知
func (s *WeChatWorkService) sendFaceVerifySuccessNotification(ctx context.Context, data map[string]interface{}) error {
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
confidence := data["confidence"].(float64)
content := fmt.Sprintf(`## ✅ 人脸识别成功
**申请人**: %s
**申请ID**: %s
**识别置信度**: %.2f%%
**识别时间**: %s
身份验证通过,可以进行下一步操作。`,
applicantName,
applicationID,
confidence*100,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendFaceVerifyFailedNotification 发送人脸识别失败通知
func (s *WeChatWorkService) sendFaceVerifyFailedNotification(ctx context.Context, data map[string]interface{}) error {
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
errorMsg := data["error_message"].(string)
content := fmt.Sprintf(`## ❌ 人脸识别失败
**申请人**: %s
**申请ID**: %s
**错误信息**: %s
**失败时间**: %s
请重新进行人脸识别或联系技术支持。`,
applicantName,
applicationID,
errorMsg,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendAdminApprovedNotification 发送管理员审核通过通知
func (s *WeChatWorkService) sendAdminApprovedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
adminName := data["admin_name"].(string)
comment := data["comment"].(string)
content := fmt.Sprintf(`## ✅ 管理员审核通过
**企业名称**: %s
**申请ID**: %s
**审核人**: %s
**审核意见**: %s
**审核时间**: %s
认证申请已通过审核,请用户签署电子合同。`,
companyName,
applicationID,
adminName,
comment,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendAdminRejectedNotification 发送管理员审核拒绝通知
func (s *WeChatWorkService) sendAdminRejectedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
adminName := data["admin_name"].(string)
reason := data["reason"].(string)
content := fmt.Sprintf(`## ❌ 管理员审核拒绝
**企业名称**: %s
**申请ID**: %s
**审核人**: %s
**拒绝原因**: %s
**审核时间**: %s
认证申请被拒绝,请根据反馈意见重新提交。`,
companyName,
applicationID,
adminName,
reason,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendContractSignedNotification 发送合同签署通知
func (s *WeChatWorkService) sendContractSignedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
signerName := data["signer_name"].(string)
content := fmt.Sprintf(`## 📝 电子合同已签署
**企业名称**: %s
**申请ID**: %s
**签署人**: %s
**签署时间**: %s
电子合同签署完成系统将自动生成钱包和Access Key。`,
companyName,
applicationID,
signerName,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendCertificationCompletedNotification 发送认证完成通知
func (s *WeChatWorkService) sendCertificationCompletedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
walletAddress := data["wallet_address"].(string)
content := fmt.Sprintf(`## 🎉 企业认证完成
**企业名称**: %s
**申请ID**: %s
**钱包地址**: %s
**完成时间**: %s
恭喜企业认证流程已完成钱包和Access Key已生成。`,
companyName,
applicationID,
walletAddress,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendMessage 发送消息到企业微信
func (s *WeChatWorkService) sendMessage(ctx context.Context, message map[string]interface{}) error {
// 生成签名URL
signedURL := s.generateSignedURL()
// 序列化消息
messageBytes, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("序列化消息失败: %w", err)
}
// 创建HTTP客户端
client := &http.Client{
Timeout: s.timeout,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "POST", signedURL, bytes.NewBuffer(messageBytes))
if err != nil {
return fmt.Errorf("创建请求失败: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "tyapi-server/1.0")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
}
// 解析响应
var response map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误码
if errCode, ok := response["errcode"].(float64); ok && errCode != 0 {
errmsg := response["errmsg"].(string)
return fmt.Errorf("企业微信API错误: %d - %s", int(errCode), errmsg)
}
s.logger.Info("企业微信消息发送成功", zap.Any("response", response))
return nil
}
// generateSignedURL 生成带签名的URL
func (s *WeChatWorkService) generateSignedURL() string {
if s.secret == "" {
return s.webhookURL
}
// 生成时间戳
timestamp := time.Now().Unix()
// 生成随机字符串(这里简化处理,实际应该使用随机字符串)
nonce := fmt.Sprintf("%d", timestamp)
// 构建签名字符串
signStr := fmt.Sprintf("%d\n%s", timestamp, s.secret)
// 计算签名
h := hmac.New(sha256.New, []byte(s.secret))
h.Write([]byte(signStr))
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
// 构建签名URL
return fmt.Sprintf("%s&timestamp=%d&nonce=%s&sign=%s",
s.webhookURL, timestamp, nonce, signature)
}
// SendSystemAlert 发送系统告警
func (s *WeChatWorkService) SendSystemAlert(ctx context.Context, level, title, message string) error {
s.logger.Info("发送系统告警",
zap.String("level", level),
zap.String("title", title),
)
// 根据告警级别选择图标
var icon string
switch level {
case "info":
icon = ""
case "warning":
icon = "⚠️"
case "error":
icon = "🚨"
case "critical":
icon = "💥"
default:
icon = "📢"
}
content := fmt.Sprintf(`## %s 系统告警
**级别**: %s
**标题**: %s
**消息**: %s
**时间**: %s
请相关人员及时处理。`,
icon,
level,
title,
message,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// SendDailyReport 发送每日报告
func (s *WeChatWorkService) SendDailyReport(ctx context.Context, reportData map[string]interface{}) error {
s.logger.Info("发送每日报告")
content := fmt.Sprintf(`## 📊 企业认证系统每日报告
**报告日期**: %s
### 统计数据
- **新增申请**: %d
- **OCR识别成功**: %d
- **OCR识别失败**: %d
- **人脸识别成功**: %d
- **人脸识别失败**: %d
- **审核通过**: %d
- **审核拒绝**: %d
- **认证完成**: %d
### 系统状态
- **系统运行时间**: %s
- **API调用次数**: %d
- **错误次数**: %d
祝您工作愉快!`,
time.Now().Format("2006-01-02"),
reportData["new_applications"],
reportData["ocr_success"],
reportData["ocr_failed"],
reportData["face_verify_success"],
reportData["face_verify_failed"],
reportData["admin_approved"],
reportData["admin_rejected"],
reportData["certification_completed"],
reportData["uptime"],
reportData["api_calls"],
reportData["errors"])
return s.SendMarkdownMessage(ctx, content)
}

View File

@@ -0,0 +1,505 @@
package ocr
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"go.uber.org/zap"
"tyapi-server/internal/application/certification/dto/responses"
)
// BaiduOCRService 百度OCR服务
type BaiduOCRService struct {
apiKey string
secretKey string
endpoint string
timeout time.Duration
logger *zap.Logger
}
// NewBaiduOCRService 创建百度OCR服务
func NewBaiduOCRService(apiKey, secretKey string, logger *zap.Logger) *BaiduOCRService {
return &BaiduOCRService{
apiKey: apiKey,
secretKey: secretKey,
endpoint: "https://aip.baidubce.com",
timeout: 30 * time.Second,
logger: logger,
}
}
// RecognizeBusinessLicense 识别营业执照
func (s *BaiduOCRService) RecognizeBusinessLicense(ctx context.Context, imageBytes []byte) (*responses.BusinessLicenseResult, error) {
s.logger.Info("开始识别营业执照", zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/business_license?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("营业执照识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
licenseResult := s.parseBusinessLicenseResult(result)
s.logger.Info("营业执照识别成功",
zap.String("company_name", licenseResult.CompanyName),
zap.String("legal_representative", licenseResult.LegalPersonName),
zap.String("registered_capital", licenseResult.RegisteredCapital),
)
return licenseResult, nil
}
// RecognizeIDCard 识别身份证
func (s *BaiduOCRService) RecognizeIDCard(ctx context.Context, imageBytes []byte, side string) (*responses.IDCardResult, error) {
s.logger.Info("开始识别身份证", zap.String("side", side), zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/idcard?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s&side=%s", imageBase64UrlEncoded, side))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("身份证识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
idCardResult := s.parseIDCardResult(result, side)
s.logger.Info("身份证识别成功",
zap.String("name", idCardResult.Name),
zap.String("id_number", idCardResult.IDCardNumber),
zap.String("side", side),
)
return idCardResult, nil
}
// RecognizeGeneralText 通用文字识别
func (s *BaiduOCRService) RecognizeGeneralText(ctx context.Context, imageBytes []byte) (*responses.GeneralTextResult, error) {
s.logger.Info("开始通用文字识别", zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/general_basic?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("通用文字识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
textResult := s.parseGeneralTextResult(result)
s.logger.Info("通用文字识别成功",
zap.Int("word_count", len(textResult.Words)),
zap.Float64("confidence", textResult.Confidence),
)
return textResult, nil
}
// RecognizeFromURL 从URL识别图片
func (s *BaiduOCRService) RecognizeFromURL(ctx context.Context, imageURL string, ocrType string) (interface{}, error) {
s.logger.Info("从URL识别图片", zap.String("url", imageURL), zap.String("type", ocrType))
// 下载图片
imageBytes, err := s.downloadImage(ctx, imageURL)
if err != nil {
s.logger.Error("下载图片失败", zap.Error(err))
return nil, fmt.Errorf("下载图片失败: %w", err)
}
// 根据类型调用相应的识别方法
switch ocrType {
case "business_license":
return s.RecognizeBusinessLicense(ctx, imageBytes)
case "idcard_front":
return s.RecognizeIDCard(ctx, imageBytes, "front")
case "idcard_back":
return s.RecognizeIDCard(ctx, imageBytes, "back")
case "general_text":
return s.RecognizeGeneralText(ctx, imageBytes)
default:
return nil, fmt.Errorf("不支持的OCR类型: %s", ocrType)
}
}
// getAccessToken 获取百度API访问令牌
func (s *BaiduOCRService) getAccessToken(ctx context.Context) (string, error) {
// 构建获取访问令牌的URL
tokenURL := fmt.Sprintf("%s/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
s.endpoint, s.apiKey, s.secretKey)
// 发送请求
resp, err := s.sendRequest(ctx, "POST", tokenURL, nil)
if err != nil {
return "", fmt.Errorf("获取访问令牌请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return "", fmt.Errorf("解析访问令牌响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error"].(string); ok && errCode != "" {
errorDesc := result["error_description"].(string)
return "", fmt.Errorf("获取访问令牌失败: %s - %s", errCode, errorDesc)
}
// 提取访问令牌
accessToken, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("响应中未找到访问令牌")
}
return accessToken, nil
}
// sendRequest 发送HTTP请求
func (s *BaiduOCRService) sendRequest(ctx context.Context, method, url string, body io.Reader) ([]byte, error) {
// 创建HTTP客户端
client := &http.Client{
Timeout: s.timeout,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("User-Agent", "tyapi-server/1.0")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
}
// 读取响应内容
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败: %w", err)
}
return responseBody, nil
}
// parseBusinessLicenseResult 解析营业执照识别结果
func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface{}) *responses.BusinessLicenseResult {
wordsResult := result["words_result"].(map[string]interface{})
// 提取企业信息
companyName := ""
if companyNameObj, ok := wordsResult["单位名称"].(map[string]interface{}); ok {
companyName = companyNameObj["words"].(string)
}
unifiedSocialCode := ""
if socialCreditCodeObj, ok := wordsResult["社会信用代码"].(map[string]interface{}); ok {
unifiedSocialCode = socialCreditCodeObj["words"].(string)
}
legalPersonName := ""
if legalPersonObj, ok := wordsResult["法人"].(map[string]interface{}); ok {
legalPersonName = legalPersonObj["words"].(string)
}
// 提取注册资本等其他信息
registeredCapital := ""
if registeredCapitalObj, ok := wordsResult["注册资本"].(map[string]interface{}); ok {
registeredCapital = registeredCapitalObj["words"].(string)
}
// 计算置信度这里简化处理实际应该从OCR结果中获取
confidence := 0.9 // 默认置信度
return &responses.BusinessLicenseResult{
CompanyName: companyName,
UnifiedSocialCode: unifiedSocialCode,
LegalPersonName: legalPersonName,
RegisteredCapital: registeredCapital,
Confidence: confidence,
}
}
// parseIDCardResult 解析身份证识别结果
func (s *BaiduOCRService) parseIDCardResult(result map[string]interface{}, side string) *responses.IDCardResult {
wordsResult := result["words_result"].(map[string]interface{})
idCardResult := &responses.IDCardResult{
Side: side,
Confidence: s.extractConfidence(result),
}
if side == "front" {
if name, ok := wordsResult["姓名"]; ok {
if word, ok := name.(map[string]interface{}); ok {
idCardResult.Name = word["words"].(string)
}
}
if gender, ok := wordsResult["性别"]; ok {
if word, ok := gender.(map[string]interface{}); ok {
idCardResult.Gender = word["words"].(string)
}
}
if nation, ok := wordsResult["民族"]; ok {
if word, ok := nation.(map[string]interface{}); ok {
idCardResult.Nation = word["words"].(string)
}
}
if birthday, ok := wordsResult["出生"]; ok {
if word, ok := birthday.(map[string]interface{}); ok {
idCardResult.Birthday = word["words"].(string)
}
}
if address, ok := wordsResult["住址"]; ok {
if word, ok := address.(map[string]interface{}); ok {
idCardResult.Address = word["words"].(string)
}
}
if idNumber, ok := wordsResult["公民身份号码"]; ok {
if word, ok := idNumber.(map[string]interface{}); ok {
idCardResult.IDCardNumber = word["words"].(string)
}
}
} else {
if issuingAgency, ok := wordsResult["签发机关"]; ok {
if word, ok := issuingAgency.(map[string]interface{}); ok {
idCardResult.IssuingAgency = word["words"].(string)
}
}
if validPeriod, ok := wordsResult["有效期限"]; ok {
if word, ok := validPeriod.(map[string]interface{}); ok {
idCardResult.ValidPeriod = word["words"].(string)
}
}
}
return idCardResult
}
// parseGeneralTextResult 解析通用文字识别结果
func (s *BaiduOCRService) parseGeneralTextResult(result map[string]interface{}) *responses.GeneralTextResult {
wordsResult := result["words_result"].([]interface{})
textResult := &responses.GeneralTextResult{
Confidence: s.extractConfidence(result),
Words: make([]responses.TextLine, 0, len(wordsResult)),
}
for _, word := range wordsResult {
if wordMap, ok := word.(map[string]interface{}); ok {
line := responses.TextLine{
Text: wordMap["words"].(string),
Confidence: 1.0, // 百度返回的通用文字识别没有单独置信度
}
textResult.Words = append(textResult.Words, line)
}
}
return textResult
}
// extractConfidence 提取置信度
func (s *BaiduOCRService) extractConfidence(result map[string]interface{}) float64 {
if confidence, ok := result["confidence"].(float64); ok {
return confidence
}
return 0.0
}
// extractWords 提取识别的文字
func (s *BaiduOCRService) extractWords(result map[string]interface{}) []string {
words := make([]string, 0)
if wordsResult, ok := result["words_result"]; ok {
switch v := wordsResult.(type) {
case map[string]interface{}:
// 营业执照等结构化文档
for _, word := range v {
if wordMap, ok := word.(map[string]interface{}); ok {
if wordsStr, ok := wordMap["words"].(string); ok {
words = append(words, wordsStr)
}
}
}
case []interface{}:
// 通用文字识别
for _, word := range v {
if wordMap, ok := word.(map[string]interface{}); ok {
if wordsStr, ok := wordMap["words"].(string); ok {
words = append(words, wordsStr)
}
}
}
}
}
return words
}
// downloadImage 下载图片
func (s *BaiduOCRService) downloadImage(ctx context.Context, imageURL string) ([]byte, error) {
// 创建HTTP客户端
client := &http.Client{
Timeout: 30 * time.Second,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("下载图片失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
}
// 读取响应内容
imageBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取图片内容失败: %w", err)
}
return imageBytes, nil
}
// ValidateBusinessLicense 验证营业执照识别结果
func (s *BaiduOCRService) ValidateBusinessLicense(result *responses.BusinessLicenseResult) error {
if result.Confidence < 0.8 {
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
}
if result.CompanyName == "" {
return fmt.Errorf("未能识别公司名称")
}
if result.LegalPersonName == "" {
return fmt.Errorf("未能识别法定代表人")
}
if result.UnifiedSocialCode == "" {
return fmt.Errorf("未能识别统一社会信用代码")
}
return nil
}
// ValidateIDCard 验证身份证识别结果
func (s *BaiduOCRService) ValidateIDCard(result *responses.IDCardResult) error {
if result.Confidence < 0.8 {
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
}
if result.Side == "front" {
if result.Name == "" {
return fmt.Errorf("未能识别姓名")
}
if result.IDCardNumber == "" {
return fmt.Errorf("未能识别身份证号码")
}
} else {
if result.IssuingAgency == "" {
return fmt.Errorf("未能识别签发机关")
}
if result.ValidPeriod == "" {
return fmt.Errorf("未能识别有效期限")
}
}
return nil
}

View File

@@ -0,0 +1,123 @@
package sms
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
"go.uber.org/zap"
"tyapi-server/internal/config"
)
// AliSMSService 阿里云短信服务
type AliSMSService struct {
client *dysmsapi.Client
config config.SMSConfig
logger *zap.Logger
}
// NewAliSMSService 创建阿里云短信服务
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
if err != nil {
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
}
return &AliSMSService{
client: client,
config: cfg,
logger: logger,
}, nil
}
// SendVerificationCode 发送验证码
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.PhoneNumbers = phone
request.SignName = s.config.SignName
request.TemplateCode = s.config.TemplateCode
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("Failed to send SMS",
zap.String("phone", phone),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Code != "OK" {
s.logger.Error("SMS send failed",
zap.String("phone", phone),
zap.String("code", response.Code),
zap.String("message", response.Message))
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
}
s.logger.Info("SMS sent successfully",
zap.String("phone", phone),
zap.String("bizId", response.BizId))
return nil
}
// GenerateCode 生成验证码
func (s *AliSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
// 生成指定长度的数字验证码
max := big.NewInt(int64(pow10(length)))
n, _ := rand.Int(rand.Reader, max)
// 格式化为指定长度不足时前面补0
format := fmt.Sprintf("%%0%dd", length)
return fmt.Sprintf(format, n.Int64())
}
// pow10 计算10的n次方
func pow10(n int) int {
result := 1
for i := 0; i < n; i++ {
result *= 10
}
return result
}
// MockSMSService 模拟短信服务(用于开发和测试)
type MockSMSService struct {
logger *zap.Logger
}
// NewMockSMSService 创建模拟短信服务
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
return &MockSMSService{
logger: logger,
}
}
// SendVerificationCode 模拟发送验证码
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
s.logger.Info("Mock SMS sent",
zap.String("phone", phone),
zap.String("code", code))
return nil
}
// GenerateCode 生成验证码
func (s *MockSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
// 开发环境使用固定验证码便于测试
result := ""
for i := 0; i < length; i++ {
result += "1"
}
return result
}

View File

@@ -0,0 +1,281 @@
package storage
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"path/filepath"
"strings"
"time"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"go.uber.org/zap"
sharedStorage "tyapi-server/internal/shared/storage"
)
// QiNiuStorageService 七牛云存储服务
type QiNiuStorageService struct {
accessKey string
secretKey string
bucket string
domain string
logger *zap.Logger
mac *qbox.Mac
bucketManager *storage.BucketManager
}
// QiNiuStorageConfig 七牛云存储配置
type QiNiuStorageConfig struct {
AccessKey string `yaml:"access_key"`
SecretKey string `yaml:"secret_key"`
Bucket string `yaml:"bucket"`
Domain string `yaml:"domain"`
}
// NewQiNiuStorageService 创建七牛云存储服务
func NewQiNiuStorageService(accessKey, secretKey, bucket, domain string, logger *zap.Logger) *QiNiuStorageService {
mac := qbox.NewMac(accessKey, secretKey)
// 使用默认配置不需要指定region
cfg := storage.Config{}
bucketManager := storage.NewBucketManager(mac, &cfg)
return &QiNiuStorageService{
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
domain: domain,
logger: logger,
mac: mac,
bucketManager: bucketManager,
}
}
// UploadFile 上传文件到七牛云
func (s *QiNiuStorageService) UploadFile(ctx context.Context, fileBytes []byte, fileName string) (*sharedStorage.UploadResult, error) {
s.logger.Info("开始上传文件到七牛云",
zap.String("file_name", fileName),
zap.Int("file_size", len(fileBytes)),
)
// 生成唯一的文件key
key := s.generateFileKey(fileName)
// 创建上传凭证
putPolicy := storage.PutPolicy{
Scope: s.bucket,
}
upToken := putPolicy.UploadToken(s.mac)
// 配置上传参数
cfg := storage.Config{}
formUploader := storage.NewFormUploader(&cfg)
ret := storage.PutRet{}
// 上传文件
err := formUploader.Put(ctx, &ret, upToken, key, strings.NewReader(string(fileBytes)), int64(len(fileBytes)), &storage.PutExtra{})
if err != nil {
s.logger.Error("文件上传失败",
zap.String("file_name", fileName),
zap.String("key", key),
zap.Error(err),
)
return nil, fmt.Errorf("文件上传失败: %w", err)
}
// 构建文件URL
fileURL := s.GetFileURL(ctx, key)
s.logger.Info("文件上传成功",
zap.String("file_name", fileName),
zap.String("key", key),
zap.String("url", fileURL),
)
return &sharedStorage.UploadResult{
Key: key,
URL: fileURL,
MimeType: s.getMimeType(fileName),
Size: int64(len(fileBytes)),
Hash: ret.Hash,
}, nil
}
// GenerateUploadToken 生成上传凭证
func (s *QiNiuStorageService) GenerateUploadToken(ctx context.Context, key string) (string, error) {
putPolicy := storage.PutPolicy{
Scope: s.bucket,
// 设置过期时间1小时
Expires: uint64(time.Now().Add(time.Hour).Unix()),
}
token := putPolicy.UploadToken(s.mac)
return token, nil
}
// GetFileURL 获取文件访问URL
func (s *QiNiuStorageService) GetFileURL(ctx context.Context, key string) string {
// 如果是私有空间需要生成带签名的URL
if s.isPrivateBucket() {
deadline := time.Now().Add(time.Hour).Unix() // 1小时过期
privateAccessURL := storage.MakePrivateURL(s.mac, s.domain, key, deadline)
return privateAccessURL
}
// 公开空间直接返回URL
return fmt.Sprintf("%s/%s", s.domain, key)
}
// GetPrivateFileURL 获取私有文件访问URL
func (s *QiNiuStorageService) GetPrivateFileURL(ctx context.Context, key string, expires int64) (string, error) {
baseURL := s.GetFileURL(ctx, key)
// TODO: 实际集成七牛云SDK生成私有URL
s.logger.Info("生成七牛云私有文件URL",
zap.String("key", key),
zap.Int64("expires", expires),
)
// 模拟返回私有URL
return fmt.Sprintf("%s?token=mock_private_token&expires=%d", baseURL, expires), nil
}
// DeleteFile 删除文件
func (s *QiNiuStorageService) DeleteFile(ctx context.Context, key string) error {
s.logger.Info("删除七牛云文件", zap.String("key", key))
err := s.bucketManager.Delete(s.bucket, key)
if err != nil {
s.logger.Error("删除文件失败",
zap.String("key", key),
zap.Error(err),
)
return fmt.Errorf("删除文件失败: %w", err)
}
s.logger.Info("文件删除成功", zap.String("key", key))
return nil
}
// FileExists 检查文件是否存在
func (s *QiNiuStorageService) FileExists(ctx context.Context, key string) (bool, error) {
// TODO: 实际集成七牛云SDK检查文件存在性
s.logger.Info("检查七牛云文件存在性", zap.String("key", key))
// 模拟文件存在
return true, nil
}
// GetFileInfo 获取文件信息
func (s *QiNiuStorageService) GetFileInfo(ctx context.Context, key string) (*sharedStorage.FileInfo, error) {
fileInfo, err := s.bucketManager.Stat(s.bucket, key)
if err != nil {
s.logger.Error("获取文件信息失败",
zap.String("key", key),
zap.Error(err),
)
return nil, fmt.Errorf("获取文件信息失败: %w", err)
}
return &sharedStorage.FileInfo{
Key: key,
Size: fileInfo.Fsize,
MimeType: fileInfo.MimeType,
Hash: fileInfo.Hash,
PutTime: fileInfo.PutTime,
}, nil
}
// ListFiles 列出文件
func (s *QiNiuStorageService) ListFiles(ctx context.Context, prefix string, limit int) ([]*sharedStorage.FileInfo, error) {
entries, _, _, hasMore, err := s.bucketManager.ListFiles(s.bucket, prefix, "", "", limit)
if err != nil {
s.logger.Error("列出文件失败",
zap.String("prefix", prefix),
zap.Error(err),
)
return nil, fmt.Errorf("列出文件失败: %w", err)
}
var fileInfos []*sharedStorage.FileInfo
for _, entry := range entries {
fileInfo := &sharedStorage.FileInfo{
Key: entry.Key,
Size: entry.Fsize,
MimeType: entry.MimeType,
Hash: entry.Hash,
PutTime: entry.PutTime,
}
fileInfos = append(fileInfos, fileInfo)
}
_ = hasMore // 暂时忽略hasMore
return fileInfos, nil
}
// generateFileKey 生成文件key
func (s *QiNiuStorageService) generateFileKey(fileName string) string {
// 生成时间戳
timestamp := time.Now().Format("20060102_150405")
// 生成随机字符串
randomStr := fmt.Sprintf("%d", time.Now().UnixNano()%1000000)
// 获取文件扩展名
ext := filepath.Ext(fileName)
// 构建key: 日期/时间戳_随机数.扩展名
key := fmt.Sprintf("certification/%s/%s_%s%s",
time.Now().Format("20060102"), timestamp, randomStr, ext)
return key
}
// getMimeType 根据文件名获取MIME类型
func (s *QiNiuStorageService) getMimeType(fileName string) string {
ext := strings.ToLower(filepath.Ext(fileName))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".pdf":
return "application/pdf"
case ".gif":
return "image/gif"
case ".bmp":
return "image/bmp"
case ".webp":
return "image/webp"
default:
return "application/octet-stream"
}
}
// isPrivateBucket 判断是否为私有空间
func (s *QiNiuStorageService) isPrivateBucket() bool {
// 这里可以根据配置或域名特征判断
// 私有空间的域名通常包含特定标识
return strings.Contains(s.domain, "private") ||
strings.Contains(s.domain, "auth") ||
strings.Contains(s.domain, "secure")
}
// generateSignature 生成签名(用于私有空间访问)
func (s *QiNiuStorageService) generateSignature(data string) string {
h := hmac.New(sha1.New, []byte(s.secretKey))
h.Write([]byte(data))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
// UploadFromReader 从Reader上传文件
func (s *QiNiuStorageService) UploadFromReader(ctx context.Context, reader io.Reader, fileName string, fileSize int64) (*sharedStorage.UploadResult, error) {
// 读取文件内容
fileBytes, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
return s.UploadFile(ctx, fileBytes, fileName)
}

View File

@@ -0,0 +1,280 @@
package handlers
import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/application/admin"
"tyapi-server/internal/application/admin/dto/commands"
"tyapi-server/internal/application/admin/dto/queries"
"tyapi-server/internal/shared/interfaces"
)
// AdminHandler 管理员HTTP处理器
type AdminHandler struct {
appService admin.AdminApplicationService
responseBuilder interfaces.ResponseBuilder
logger *zap.Logger
}
// NewAdminHandler 创建管理员HTTP处理器
func NewAdminHandler(
appService admin.AdminApplicationService,
responseBuilder interfaces.ResponseBuilder,
logger *zap.Logger,
) *AdminHandler {
return &AdminHandler{
appService: appService,
responseBuilder: responseBuilder,
logger: logger,
}
}
// Login 管理员登录
// @Summary 管理员登录
// @Description 使用用户名和密码进行管理员登录返回JWT令牌
// @Tags 管理员认证
// @Accept json
// @Produce json
// @Param request body commands.AdminLoginCommand true "管理员登录请求"
// @Success 200 {object} responses.AdminLoginResponse "登录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "用户名或密码错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/auth/login [post]
func (h *AdminHandler) Login(c *gin.Context) {
var cmd commands.AdminLoginCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
response, err := h.appService.Login(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("管理员登录失败", zap.Error(err))
h.responseBuilder.Unauthorized(c, err.Error())
return
}
h.responseBuilder.Success(c, response, "登录成功")
}
// CreateAdmin 创建管理员
// @Summary 创建管理员
// @Description 创建新的管理员账户,需要超级管理员权限
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateAdminCommand true "创建管理员请求"
// @Success 201 {object} map[string]interface{} "管理员创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 403 {object} map[string]interface{} "权限不足"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin [post]
func (h *AdminHandler) CreateAdmin(c *gin.Context) {
var cmd commands.CreateAdminCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.OperatorID = h.getCurrentAdminID(c)
if err := h.appService.CreateAdmin(c.Request.Context(), &cmd); err != nil {
h.logger.Error("创建管理员失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, nil, "管理员创建成功")
}
// UpdateAdmin 更新管理员
// @Summary 更新管理员信息
// @Description 更新指定管理员的基本信息
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "管理员ID"
// @Param request body commands.UpdateAdminCommand true "更新管理员请求"
// @Success 200 {object} map[string]interface{} "管理员更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 403 {object} map[string]interface{} "权限不足"
// @Failure 404 {object} map[string]interface{} "管理员不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/{id} [put]
func (h *AdminHandler) UpdateAdmin(c *gin.Context) {
var cmd commands.UpdateAdminCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.AdminID = c.Param("id")
cmd.OperatorID = h.getCurrentAdminID(c)
if err := h.appService.UpdateAdmin(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新管理员失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "管理员更新成功")
}
// ChangePassword 修改密码
// @Summary 修改管理员密码
// @Description 修改当前登录管理员的密码
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.ChangeAdminPasswordCommand true "修改密码请求"
// @Success 200 {object} map[string]interface{} "密码修改成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/change-password [post]
func (h *AdminHandler) ChangePassword(c *gin.Context) {
var cmd commands.ChangeAdminPasswordCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.AdminID = h.getCurrentAdminID(c)
if err := h.appService.ChangePassword(c.Request.Context(), &cmd); err != nil {
h.logger.Error("修改密码失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "密码修改成功")
}
// ListAdmins 获取管理员列表
// @Summary 获取管理员列表
// @Description 分页获取管理员列表,支持搜索和筛选
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param page query int false "页码" default(1)
// @Param size query int false "每页数量" default(10)
// @Param keyword query string false "搜索关键词"
// @Param status query string false "状态筛选"
// @Success 200 {object} responses.AdminListResponse "获取管理员列表成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin [get]
func (h *AdminHandler) ListAdmins(c *gin.Context) {
var query queries.ListAdminsQuery
if err := c.ShouldBindQuery(&query); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
response, err := h.appService.ListAdmins(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取管理员列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取管理员列表失败")
return
}
h.responseBuilder.Success(c, response, "获取管理员列表成功")
}
// GetAdminByID 根据ID获取管理员
// @Summary 获取管理员详情
// @Description 根据管理员ID获取详细信息
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "管理员ID"
// @Success 200 {object} responses.AdminInfoResponse "获取管理员详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "管理员不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/{id} [get]
func (h *AdminHandler) GetAdminByID(c *gin.Context) {
var query queries.GetAdminInfoQuery
if err := c.ShouldBindUri(&query); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
admin, err := h.appService.GetAdminByID(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取管理员详情失败", zap.Error(err))
h.responseBuilder.NotFound(c, err.Error())
return
}
h.responseBuilder.Success(c, admin, "获取管理员详情成功")
}
// DeleteAdmin 删除管理员
// @Summary 删除管理员
// @Description 删除指定的管理员账户
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "管理员ID"
// @Success 200 {object} map[string]interface{} "管理员删除成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 403 {object} map[string]interface{} "权限不足"
// @Failure 404 {object} map[string]interface{} "管理员不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/{id} [delete]
func (h *AdminHandler) DeleteAdmin(c *gin.Context) {
var cmd commands.DeleteAdminCommand
cmd.AdminID = c.Param("id")
cmd.OperatorID = h.getCurrentAdminID(c)
if err := h.appService.DeleteAdmin(c.Request.Context(), &cmd); err != nil {
h.logger.Error("删除管理员失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "管理员删除成功")
}
// GetAdminStats 获取管理员统计信息
// @Summary 获取管理员统计信息
// @Description 获取管理员相关的统计数据
// @Tags 管理员管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.AdminStatsResponse "获取统计信息成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/stats [get]
func (h *AdminHandler) GetAdminStats(c *gin.Context) {
stats, err := h.appService.GetAdminStats(c.Request.Context())
if err != nil {
h.logger.Error("获取管理员统计失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取统计信息失败")
return
}
h.responseBuilder.Success(c, stats, "获取统计信息成功")
}
// getCurrentAdminID 获取当前管理员ID
func (h *AdminHandler) getCurrentAdminID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}

View File

@@ -0,0 +1,472 @@
package handlers
import (
"io"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/application/certification"
"tyapi-server/internal/application/certification/dto/commands"
"tyapi-server/internal/application/certification/dto/queries"
"tyapi-server/internal/shared/interfaces"
)
// CertificationHandler 认证处理器
type CertificationHandler struct {
appService certification.CertificationApplicationService
response interfaces.ResponseBuilder
logger *zap.Logger
}
// NewCertificationHandler 创建认证处理器
func NewCertificationHandler(
appService certification.CertificationApplicationService,
response interfaces.ResponseBuilder,
logger *zap.Logger,
) *CertificationHandler {
return &CertificationHandler{
appService: appService,
response: response,
logger: logger,
}
}
// CreateCertification 创建认证申请
// @Summary 创建认证申请
// @Description 为用户创建新的企业认证申请
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.CertificationResponse "认证申请创建成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification [post]
func (h *CertificationHandler) CreateCertification(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
cmd := &commands.CreateCertificationCommand{UserID: userID}
result, err := h.appService.CreateCertification(c.Request.Context(), cmd)
if err != nil {
h.logger.Error("创建认证申请失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.InternalError(c, "创建认证申请失败")
return
}
h.response.Success(c, result, "认证申请创建成功")
}
// UploadBusinessLicense 上传营业执照并同步OCR识别
// @Summary 上传营业执照并同步OCR识别
// @Description 上传营业执照文件立即进行OCR识别并返回结果
// @Tags 企业认证
// @Accept multipart/form-data
// @Produce json
// @Param file formData file true "营业执照文件"
// @Security Bearer
// @Success 200 {object} responses.UploadLicenseResponse "上传成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/upload-license [post]
func (h *CertificationHandler) UploadBusinessLicense(c *gin.Context) {
// 获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
h.response.Unauthorized(c, "用户未认证")
return
}
// 获取上传的文件
file, err := c.FormFile("file")
if err != nil {
h.response.BadRequest(c, "文件上传失败")
return
}
// 读取文件内容
openedFile, err := file.Open()
if err != nil {
h.response.BadRequest(c, "无法读取文件")
return
}
defer openedFile.Close()
fileBytes, err := io.ReadAll(openedFile)
if err != nil {
h.response.BadRequest(c, "文件读取失败")
return
}
// 调用应用服务
response, err := h.appService.UploadBusinessLicense(c.Request.Context(), userID.(string), fileBytes, file.Filename)
if err != nil {
h.logger.Error("营业执照上传失败", zap.Error(err))
h.response.InternalError(c, "营业执照上传失败")
return
}
h.response.Success(c, response, "营业执照上传成功")
}
// GetCertificationStatus 获取认证状态
// @Summary 获取认证状态
// @Description 获取当前用户的认证申请状态
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.CertificationResponse "获取认证状态成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/status [get]
func (h *CertificationHandler) GetCertificationStatus(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
query := &queries.GetCertificationStatusQuery{UserID: userID}
result, err := h.appService.GetCertificationStatus(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取认证状态失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取认证状态成功")
}
// GetProgressStats 获取进度统计
// @Summary 获取进度统计
// @Description 获取认证申请的进度统计数据
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} map[string]interface{} "获取进度统计成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/stats [get]
func (h *CertificationHandler) GetProgressStats(c *gin.Context) {
// 这里应该实现获取进度统计的逻辑
// 暂时返回空数据
h.response.Success(c, map[string]interface{}{
"total_applications": 0,
"pending": 0,
"in_progress": 0,
"completed": 0,
"rejected": 0,
}, "获取进度统计成功")
}
// GetCertificationProgress 获取认证进度
// @Summary 获取认证进度
// @Description 获取当前用户的认证申请详细进度信息
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} map[string]interface{} "获取认证进度成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证申请不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/progress [get]
func (h *CertificationHandler) GetCertificationProgress(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
result, err := h.appService.GetCertificationProgress(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取认证进度失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取认证进度成功")
}
// SubmitEnterpriseInfo 提交企业信息
// @Summary 提交企业信息
// @Description 提交企业基本信息,包括企业名称、统一社会信用代码、法定代表人信息等
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.SubmitEnterpriseInfoCommand true "企业信息"
// @Success 200 {object} responses.CertificationResponse "企业信息提交成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/enterprise-info [post]
func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
var cmd commands.SubmitEnterpriseInfoCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.logger.Error("参数绑定失败", zap.Error(err))
h.response.BadRequest(c, "请求参数格式错误")
return
}
cmd.UserID = userID
result, err := h.appService.SubmitEnterpriseInfo(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("提交企业信息失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "企业信息提交成功")
}
// InitiateFaceVerify 发起人脸验证
// @Summary 发起人脸验证
// @Description 发起企业法人人脸验证流程
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.InitiateFaceVerifyCommand true "人脸验证请求"
// @Success 200 {object} responses.FaceVerifyResponse "人脸验证发起成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/face-verify [post]
func (h *CertificationHandler) InitiateFaceVerify(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
var cmd commands.InitiateFaceVerifyCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.logger.Error("参数绑定失败", zap.Error(err))
h.response.BadRequest(c, "请求参数格式错误")
return
}
// 根据用户ID获取认证申请
query := &queries.GetCertificationStatusQuery{UserID: userID}
certification, err := h.appService.GetCertificationStatus(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取认证申请失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
// 如果用户没有认证申请,返回错误
if certification.ID == "" {
h.response.BadRequest(c, "用户尚未创建认证申请")
return
}
cmd.CertificationID = certification.ID
result, err := h.appService.InitiateFaceVerify(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("发起人脸验证失败",
zap.String("certification_id", certification.ID),
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "人脸验证发起成功")
}
// ApplyContract 申请合同
// @Summary 申请合同
// @Description 申请企业认证合同
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.CertificationResponse "合同申请成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/contract [post]
func (h *CertificationHandler) ApplyContract(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
result, err := h.appService.ApplyContract(c.Request.Context(), userID)
if err != nil {
h.logger.Error("申请合同失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "合同申请成功")
}
// GetCertificationDetails 获取认证详情
// @Summary 获取认证详情
// @Description 获取当前用户的认证申请详细信息
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.CertificationResponse "获取认证详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证申请不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/details [get]
func (h *CertificationHandler) GetCertificationDetails(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
query := &queries.GetCertificationDetailsQuery{
UserID: userID,
}
result, err := h.appService.GetCertificationDetails(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取认证详情失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取认证详情成功")
}
// RetryStep 重试步骤
// @Summary 重试认证步骤
// @Description 重新执行指定的认证步骤
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Param step path string true "步骤名称"
// @Success 200 {object} map[string]interface{} "步骤重试成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/retry/{step} [post]
func (h *CertificationHandler) RetryStep(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
step := c.Param("step")
if step == "" {
h.response.BadRequest(c, "步骤名称不能为空")
return
}
var result interface{}
var err error
switch step {
case "face_verify":
result, err = h.appService.RetryFaceVerify(c.Request.Context(), userID)
case "contract_sign":
result, err = h.appService.RetryContractSign(c.Request.Context(), userID)
default:
h.response.BadRequest(c, "不支持的步骤类型")
return
}
if err != nil {
h.logger.Error("重试认证步骤失败",
zap.String("user_id", userID),
zap.String("step", step),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "认证步骤重试成功")
}
// GetLicenseOCRResult 获取营业执照OCR识别结果
// @Summary 获取营业执照OCR识别结果
// @Description 根据上传记录ID获取OCR识别结果
// @Tags 企业认证
// @Accept json
// @Produce json
// @Security Bearer
// @Param record_id path string true "上传记录ID"
// @Success 200 {object} responses.UploadLicenseResponse "获取OCR结果成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certification/license/{record_id}/ocr-result [get]
func (h *CertificationHandler) GetLicenseOCRResult(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
recordID := c.Param("record_id")
if recordID == "" {
h.response.BadRequest(c, "上传记录ID不能为空")
return
}
result, err := h.appService.GetLicenseOCRResult(c.Request.Context(), recordID)
if err != nil {
h.logger.Error("获取OCR结果失败",
zap.String("user_id", userID),
zap.String("record_id", recordID),
zap.Error(err),
)
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取OCR结果成功")
}

View File

@@ -0,0 +1,429 @@
package handlers
import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/application/finance"
"tyapi-server/internal/application/finance/dto/commands"
"tyapi-server/internal/application/finance/dto/queries"
"tyapi-server/internal/shared/interfaces"
)
// FinanceHandler 财务HTTP处理器
type FinanceHandler struct {
appService finance.FinanceApplicationService
responseBuilder interfaces.ResponseBuilder
logger *zap.Logger
}
// NewFinanceHandler 创建财务HTTP处理器
func NewFinanceHandler(
appService finance.FinanceApplicationService,
responseBuilder interfaces.ResponseBuilder,
logger *zap.Logger,
) *FinanceHandler {
return &FinanceHandler{
appService: appService,
responseBuilder: responseBuilder,
logger: logger,
}
}
// CreateWallet 创建钱包
// @Summary 创建钱包
// @Description 为用户创建新的钱包账户
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Param request body commands.CreateWalletCommand true "创建钱包请求"
// @Success 201 {object} responses.WalletResponse "钱包创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 409 {object} map[string]interface{} "钱包已存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet [post]
func (h *FinanceHandler) CreateWallet(c *gin.Context) {
var cmd commands.CreateWalletCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
response, err := h.appService.CreateWallet(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("创建钱包失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, response, "钱包创建成功")
}
// GetWallet 获取钱包信息
// @Summary 获取钱包信息
// @Description 获取当前用户的钱包详细信息
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.WalletResponse "获取钱包信息成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "钱包不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet [get]
func (h *FinanceHandler) GetWallet(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
query := &queries.GetWalletInfoQuery{UserID: userID}
result, err := h.appService.GetWallet(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取钱包信息失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "获取钱包信息成功")
}
// UpdateWallet 更新钱包
// @Summary 更新钱包信息
// @Description 更新当前用户的钱包基本信息
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.UpdateWalletCommand true "更新钱包请求"
// @Success 200 {object} map[string]interface{} "钱包更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet [put]
func (h *FinanceHandler) UpdateWallet(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
var cmd commands.UpdateWalletCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.UserID = userID
err := h.appService.UpdateWallet(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("更新钱包失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "钱包更新成功")
}
// Recharge 充值
// @Summary 钱包充值
// @Description 为钱包进行充值操作
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.RechargeWalletCommand true "充值请求"
// @Success 200 {object} responses.TransactionResponse "充值成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet/recharge [post]
func (h *FinanceHandler) Recharge(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
var cmd commands.RechargeWalletCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.UserID = userID
result, err := h.appService.Recharge(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("充值失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "充值成功")
}
// Withdraw 提现
// @Summary 钱包提现
// @Description 从钱包进行提现操作
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.WithdrawWalletCommand true "提现请求"
// @Success 200 {object} responses.TransactionResponse "提现申请已提交"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet/withdraw [post]
func (h *FinanceHandler) Withdraw(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
var cmd commands.WithdrawWalletCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.UserID = userID
result, err := h.appService.Withdraw(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("提现失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "提现申请已提交")
}
// WalletTransaction 钱包交易
// @Summary 钱包交易
// @Description 执行钱包内部交易操作
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.WalletTransactionCommand true "交易请求"
// @Success 200 {object} responses.TransactionResponse "交易成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet/transaction [post]
func (h *FinanceHandler) WalletTransaction(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
var cmd commands.WalletTransactionCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.UserID = userID
result, err := h.appService.WalletTransaction(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("钱包交易失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "交易成功")
}
// GetWalletStats 获取钱包统计
// @Summary 获取钱包统计
// @Description 获取钱包相关的统计数据
// @Tags 钱包管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.WalletStatsResponse "获取钱包统计成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/wallet/stats [get]
func (h *FinanceHandler) GetWalletStats(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
result, err := h.appService.GetWalletStats(c.Request.Context())
if err != nil {
h.logger.Error("获取钱包统计失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.InternalError(c, "获取钱包统计失败")
return
}
h.responseBuilder.Success(c, result, "获取钱包统计成功")
}
// CreateUserSecrets 创建用户密钥
// @Summary 创建用户密钥
// @Description 为用户创建API访问密钥
// @Tags 用户密钥管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateUserSecretsCommand true "创建密钥请求"
// @Success 201 {object} responses.UserSecretsResponse "用户密钥创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 409 {object} map[string]interface{} "密钥已存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/secrets [post]
func (h *FinanceHandler) CreateUserSecrets(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
var cmd commands.CreateUserSecretsCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
cmd.UserID = userID
result, err := h.appService.CreateUserSecrets(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("创建用户密钥失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, result, "用户密钥创建成功")
}
// GetUserSecrets 获取用户密钥
// @Summary 获取用户密钥
// @Description 获取当前用户的API访问密钥信息
// @Tags 用户密钥管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.UserSecretsResponse "获取用户密钥成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "密钥不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/secrets [get]
func (h *FinanceHandler) GetUserSecrets(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
query := &queries.GetUserSecretsQuery{UserID: userID}
result, err := h.appService.GetUserSecrets(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取用户密钥失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "获取用户密钥成功")
}
// RegenerateAccessKey 重新生成访问密钥
// @Summary 重新生成访问密钥
// @Description 重新生成用户的API访问密钥
// @Tags 用户密钥管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.UserSecretsResponse "访问密钥重新生成成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "密钥不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/secrets/regenerate [post]
func (h *FinanceHandler) RegenerateAccessKey(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
cmd := &commands.RegenerateAccessKeyCommand{UserID: userID}
result, err := h.appService.RegenerateAccessKey(c.Request.Context(), cmd)
if err != nil {
h.logger.Error("重新生成访问密钥失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "访问密钥重新生成成功")
}
// DeactivateUserSecrets 停用用户密钥
// @Summary 停用用户密钥
// @Description 停用用户的API访问密钥
// @Tags 用户密钥管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} map[string]interface{} "用户密钥停用成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "密钥不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/finance/secrets/deactivate [post]
func (h *FinanceHandler) DeactivateUserSecrets(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未认证")
return
}
cmd := &commands.DeactivateUserSecretsCommand{UserID: userID}
err := h.appService.DeactivateUserSecrets(c.Request.Context(), cmd)
if err != nil {
h.logger.Error("停用用户密钥失败",
zap.String("user_id", userID),
zap.Error(err),
)
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "用户密钥停用成功")
}

View File

@@ -0,0 +1,224 @@
package handlers
import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/application/user"
"tyapi-server/internal/application/user/dto/commands"
"tyapi-server/internal/shared/interfaces"
"tyapi-server/internal/shared/middleware"
)
// UserHandler 用户HTTP处理器
type UserHandler struct {
appService user.UserApplicationService
response interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
jwtAuth *middleware.JWTAuthMiddleware
}
// NewUserHandler 创建用户处理器
func NewUserHandler(
appService user.UserApplicationService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
) *UserHandler {
return &UserHandler{
appService: appService,
response: response,
validator: validator,
logger: logger,
jwtAuth: jwtAuth,
}
}
// SendCode 发送验证码
// @Summary 发送短信验证码
// @Description 向指定手机号发送验证码,支持注册、登录、修改密码等场景
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body commands.SendCodeCommand true "发送验证码请求"
// @Success 200 {object} map[string]interface{} "验证码发送成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 429 {object} map[string]interface{} "请求频率限制"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/send-code [post]
func (h *UserHandler) SendCode(c *gin.Context) {
var cmd commands.SendCodeCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
clientIP := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
if err := h.appService.SendCode(c.Request.Context(), &cmd, clientIP, userAgent); err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "验证码发送成功")
}
// Register 用户注册
// @Summary 用户注册
// @Description 使用手机号、密码和验证码进行用户注册,需要确认密码
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body commands.RegisterUserCommand true "用户注册请求"
// @Success 201 {object} responses.RegisterUserResponse "注册成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 409 {object} map[string]interface{} "手机号已存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/register [post]
func (h *UserHandler) Register(c *gin.Context) {
var cmd commands.RegisterUserCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
resp, err := h.appService.Register(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("注册用户失败", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
h.response.Created(c, resp, "用户注册成功")
}
// LoginWithPassword 密码登录
// @Summary 用户密码登录
// @Description 使用手机号和密码进行用户登录返回JWT令牌
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body commands.LoginWithPasswordCommand true "密码登录请求"
// @Success 200 {object} responses.LoginUserResponse "登录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "用户名或密码错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/login-password [post]
func (h *UserHandler) LoginWithPassword(c *gin.Context) {
var cmd commands.LoginWithPasswordCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
resp, err := h.appService.LoginWithPassword(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("密码登录失败", zap.Error(err))
h.response.Unauthorized(c, "用户名或密码错误")
return
}
h.response.Success(c, resp, "登录成功")
}
// LoginWithSMS 短信验证码登录
// @Summary 用户短信验证码登录
// @Description 使用手机号和短信验证码进行用户登录返回JWT令牌
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body commands.LoginWithSMSCommand true "短信登录请求"
// @Success 200 {object} responses.LoginUserResponse "登录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 401 {object} map[string]interface{} "认证失败"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/login-sms [post]
func (h *UserHandler) LoginWithSMS(c *gin.Context) {
var cmd commands.LoginWithSMSCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
resp, err := h.appService.LoginWithSMS(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("短信登录失败", zap.Error(err))
h.response.Unauthorized(c, err.Error())
return
}
h.response.Success(c, resp, "登录成功")
}
// GetProfile 获取当前用户信息
// @Summary 获取当前用户信息
// @Description 根据JWT令牌获取当前登录用户的详细信息
// @Tags 用户管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.UserProfileResponse "用户信息"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "用户不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/me [get]
func (h *UserHandler) GetProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
resp, err := h.appService.GetUserProfile(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取用户资料失败", zap.Error(err))
h.response.NotFound(c, "用户不存在")
return
}
h.response.Success(c, resp, "获取用户资料成功")
}
// ChangePassword 修改密码
// @Summary 修改密码
// @Description 使用旧密码、新密码确认和验证码修改当前用户的密码
// @Tags 用户管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.ChangePasswordCommand true "修改密码请求"
// @Success 200 {object} map[string]interface{} "密码修改成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/users/me/password [put]
func (h *UserHandler) ChangePassword(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
var cmd commands.ChangePasswordCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
cmd.UserID = userID
if err := h.appService.ChangePassword(c.Request.Context(), &cmd); err != nil {
h.logger.Error("修改密码失败", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "密码修改成功")
}
// getCurrentUserID 获取当前用户ID
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}

View File

@@ -0,0 +1,58 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/middleware"
"go.uber.org/zap"
)
// AdminRoutes 管理员路由注册器
type AdminRoutes struct {
handler *handlers.AdminHandler
authMiddleware *middleware.JWTAuthMiddleware
logger *zap.Logger
}
// NewAdminRoutes 创建管理员路由注册器
func NewAdminRoutes(
handler *handlers.AdminHandler,
authMiddleware *middleware.JWTAuthMiddleware,
logger *zap.Logger,
) *AdminRoutes {
return &AdminRoutes{
handler: handler,
authMiddleware: authMiddleware,
logger: logger,
}
}
// Register 注册管理员相关路由
func (r *AdminRoutes) Register(router *sharedhttp.GinRouter) {
// 管理员路由组
engine := router.GetEngine()
adminGroup := engine.Group("/api/v1/admin")
{
// 认证相关路由(无需认证)
authGroup := adminGroup.Group("/auth")
{
authGroup.POST("/login", r.handler.Login)
}
// 管理员管理路由(需要认证)
authenticated := adminGroup.Group("")
authenticated.Use(r.authMiddleware.Handle())
{
authenticated.POST("", r.handler.CreateAdmin) // 创建管理员
authenticated.GET("", r.handler.ListAdmins) // 获取管理员列表
authenticated.GET("/stats", r.handler.GetAdminStats) // 获取统计信息
authenticated.GET("/:id", r.handler.GetAdminByID) // 获取管理员详情
authenticated.PUT("/:id", r.handler.UpdateAdmin) // 更新管理员
authenticated.DELETE("/:id", r.handler.DeleteAdmin) // 删除管理员
authenticated.POST("/change-password", r.handler.ChangePassword) // 修改密码
}
}
r.logger.Info("管理员路由注册完成")
}

View File

@@ -0,0 +1,73 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/middleware"
"go.uber.org/zap"
)
// CertificationRoutes 认证路由注册器
type CertificationRoutes struct {
certificationHandler *handlers.CertificationHandler
authMiddleware *middleware.JWTAuthMiddleware
logger *zap.Logger
}
// NewCertificationRoutes 创建认证路由注册器
func NewCertificationRoutes(
certificationHandler *handlers.CertificationHandler,
authMiddleware *middleware.JWTAuthMiddleware,
logger *zap.Logger,
) *CertificationRoutes {
return &CertificationRoutes{
certificationHandler: certificationHandler,
authMiddleware: authMiddleware,
logger: logger,
}
}
// Register 注册认证相关路由
func (r *CertificationRoutes) Register(router *sharedhttp.GinRouter) {
// 认证相关路由组,需要用户认证
engine := router.GetEngine()
certificationGroup := engine.Group("/api/v1/certification")
certificationGroup.Use(r.authMiddleware.Handle())
{
// 创建认证申请
certificationGroup.POST("", r.certificationHandler.CreateCertification)
// 营业执照上传
certificationGroup.POST("/upload-license", r.certificationHandler.UploadBusinessLicense)
// 获取OCR识别结果
certificationGroup.GET("/license/:record_id/ocr-result", r.certificationHandler.GetLicenseOCRResult)
// 获取认证状态
certificationGroup.GET("/status", r.certificationHandler.GetCertificationStatus)
// 获取进度统计
certificationGroup.GET("/stats", r.certificationHandler.GetProgressStats)
// 获取认证进度
certificationGroup.GET("/progress", r.certificationHandler.GetCertificationProgress)
// 提交企业信息
certificationGroup.POST("/enterprise-info", r.certificationHandler.SubmitEnterpriseInfo)
// 发起人脸识别验证
certificationGroup.POST("/face-verify", r.certificationHandler.InitiateFaceVerify)
// 申请合同签署
certificationGroup.POST("/contract", r.certificationHandler.ApplyContract)
// 获取认证详情
certificationGroup.GET("/details", r.certificationHandler.GetCertificationDetails)
// 重试认证步骤
certificationGroup.POST("/retry/:step", r.certificationHandler.RetryStep)
}
r.logger.Info("认证路由注册完成")
}

View File

@@ -0,0 +1,61 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/middleware"
"go.uber.org/zap"
)
// FinanceRoutes 财务路由注册器
type FinanceRoutes struct {
financeHandler *handlers.FinanceHandler
authMiddleware *middleware.JWTAuthMiddleware
logger *zap.Logger
}
// NewFinanceRoutes 创建财务路由注册器
func NewFinanceRoutes(
financeHandler *handlers.FinanceHandler,
authMiddleware *middleware.JWTAuthMiddleware,
logger *zap.Logger,
) *FinanceRoutes {
return &FinanceRoutes{
financeHandler: financeHandler,
authMiddleware: authMiddleware,
logger: logger,
}
}
// Register 注册财务相关路由
func (r *FinanceRoutes) Register(router *sharedhttp.GinRouter) {
// 财务路由组,需要用户认证
engine := router.GetEngine()
financeGroup := engine.Group("/api/v1/finance")
financeGroup.Use(r.authMiddleware.Handle())
{
// 钱包相关路由
walletGroup := financeGroup.Group("/wallet")
{
walletGroup.POST("", r.financeHandler.CreateWallet) // 创建钱包
walletGroup.GET("", r.financeHandler.GetWallet) // 获取钱包信息
walletGroup.PUT("", r.financeHandler.UpdateWallet) // 更新钱包
walletGroup.POST("/recharge", r.financeHandler.Recharge) // 充值
walletGroup.POST("/withdraw", r.financeHandler.Withdraw) // 提现
walletGroup.POST("/transaction", r.financeHandler.WalletTransaction) // 钱包交易
walletGroup.GET("/stats", r.financeHandler.GetWalletStats) // 获取钱包统计
}
// 用户密钥相关路由
secretsGroup := financeGroup.Group("/secrets")
{
secretsGroup.POST("", r.financeHandler.CreateUserSecrets) // 创建用户密钥
secretsGroup.GET("", r.financeHandler.GetUserSecrets) // 获取用户密钥
secretsGroup.POST("/regenerate", r.financeHandler.RegenerateAccessKey) // 重新生成访问密钥
secretsGroup.POST("/deactivate", r.financeHandler.DeactivateUserSecrets) // 停用用户密钥
}
}
r.logger.Info("财务路由注册完成")
}

View File

@@ -0,0 +1,53 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/middleware"
"go.uber.org/zap"
)
// UserRoutes 用户路由注册器
type UserRoutes struct {
handler *handlers.UserHandler
authMiddleware *middleware.JWTAuthMiddleware
logger *zap.Logger
}
// NewUserRoutes 创建用户路由注册器
func NewUserRoutes(
handler *handlers.UserHandler,
authMiddleware *middleware.JWTAuthMiddleware,
logger *zap.Logger,
) *UserRoutes {
return &UserRoutes{
handler: handler,
authMiddleware: authMiddleware,
logger: logger,
}
}
// Register 注册用户相关路由
func (r *UserRoutes) Register(router *sharedhttp.GinRouter) {
// 用户域路由组
engine := router.GetEngine()
usersGroup := engine.Group("/api/v1/users")
{
// 公开路由(不需要认证)
usersGroup.POST("/send-code", r.handler.SendCode) // 发送验证码
usersGroup.POST("/register", r.handler.Register) // 用户注册
usersGroup.POST("/login-password", r.handler.LoginWithPassword) // 密码登录
usersGroup.POST("/login-sms", r.handler.LoginWithSMS) // 短信验证码登录
// 需要认证的路由
authenticated := usersGroup.Group("")
authenticated.Use(r.authMiddleware.Handle())
{
authenticated.GET("/me", r.handler.GetProfile) // 获取当前用户信息
authenticated.PUT("/me/password", r.handler.ChangePassword) // 修改密码
}
}
r.logger.Info("用户路由注册完成")
}