f
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ContractInfosTable = "contract_infos"
|
||||
)
|
||||
|
||||
type GormContractInfoRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
|
||||
return &GormContractInfoRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
|
||||
return r.CreateEntity(ctx, contract)
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
|
||||
var contract entities.ContractInfo
|
||||
err := r.SmartGetByID(ctx, contractID, &contract)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &contract, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
|
||||
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
|
||||
type GormEnterpriseInfoRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
|
||||
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("创建企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// Update 更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
|
||||
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("更新企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("更新企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("软删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("软删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("恢复企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
|
||||
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
|
||||
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// UpdateVerificationStatus 更新验证状态
|
||||
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_ocr_verified": isOCRVerified,
|
||||
"is_face_verified": isFaceVerified,
|
||||
"is_certified": isCertified,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新验证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新验证状态失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOCRData 更新OCR数据
|
||||
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
|
||||
updates := map[string]interface{}{
|
||||
"ocr_raw_data": rawData,
|
||||
"ocr_confidence": confidence,
|
||||
"is_ocr_verified": true,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新OCR数据失败", zap.Error(err))
|
||||
return fmt.Errorf("更新OCR数据失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteCertification 完成认证
|
||||
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
updates := map[string]interface{}{
|
||||
"is_certified": true,
|
||||
"certified_at": &now,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("完成认证失败", zap.Error(err))
|
||||
return fmt.Errorf("完成认证失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计企业信息数量
|
||||
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查企业信息是否存在
|
||||
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取企业信息列表
|
||||
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
SMSCodesTable = "sms_codes"
|
||||
)
|
||||
|
||||
// GormSMSCodeRepository 短信验证码GORM仓储实现(无缓存,确保安全性)
|
||||
type GormSMSCodeRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// NewGormSMSCodeRepository 创建短信验证码仓储
|
||||
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
|
||||
return &GormSMSCodeRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
|
||||
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
|
||||
|
||||
// ================ Repository[T] 接口实现 ================
|
||||
|
||||
// Create 创建短信验证码记录(不缓存,确保安全性)
|
||||
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
|
||||
err := r.GetDB(ctx).Create(&smsCode).Error
|
||||
return smsCode, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// Update 更新验证码记录
|
||||
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCode).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建短信验证码
|
||||
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Create(&smsCodes).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&smsCodes).Error
|
||||
return smsCodes, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新短信验证码
|
||||
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCodes).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除短信验证码
|
||||
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取短信验证码列表
|
||||
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
query = query.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return smsCodes, query.Find(&smsCodes).Error
|
||||
}
|
||||
|
||||
// ================ BaseRepository 接口实现 ================
|
||||
|
||||
// Delete 删除短信验证码
|
||||
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Exists 检查短信验证码是否存在
|
||||
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Count 统计短信验证码数量
|
||||
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除短信验证码
|
||||
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复短信验证码
|
||||
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetLatestByPhone 根据手机号获取最新短信验证码
|
||||
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhone 根据手机号获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
|
||||
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
|
||||
var smsCodes []*entities.SMSCode
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
db := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("phone = ?", query.Phone)
|
||||
}
|
||||
if query.Purpose != "" {
|
||||
db = db.Where("scene = ?", query.Purpose)
|
||||
}
|
||||
if query.Status != "" {
|
||||
db = db.Where("used = ?", query.Status == "used")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Order("created_at DESC").Find(&smsCodes).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return smsCodes, total, nil
|
||||
}
|
||||
|
||||
// CreateCode 创建验证码
|
||||
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
|
||||
smsCode := entities.SMSCode{
|
||||
Phone: phone,
|
||||
Code: code,
|
||||
Scene: entities.SMSScene(purpose), // 使用Scene字段
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
|
||||
}
|
||||
|
||||
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// ValidateCode 验证验证码
|
||||
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// InvalidateCode 使验证码失效
|
||||
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND used_at IS NULL", phone).
|
||||
Update("used_at", &now).Error
|
||||
}
|
||||
|
||||
// CheckSendFrequency 检查发送频率
|
||||
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
|
||||
// 检查1分钟内是否已发送
|
||||
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
|
||||
Count(&count).Error
|
||||
|
||||
// 如果1分钟内已发送,则返回false(不允许发送)
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
// GetTodaySendCount 获取今日发送数量
|
||||
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetCodeStats 获取验证码统计
|
||||
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
|
||||
var stats repositories.SMSCodeStats
|
||||
|
||||
// 计算指定天数前的日期
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 总发送数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, startDate).
|
||||
Count(&stats.TotalSent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总验证数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
|
||||
Count(&stats.TotalValidated).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 成功率
|
||||
if stats.TotalSent > 0 {
|
||||
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
|
||||
}
|
||||
|
||||
// 今日发送数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&stats.TodaySent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,720 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersTable = "users"
|
||||
UserCacheTTL = 30 * 60 // 30分钟
|
||||
)
|
||||
|
||||
// 定义错误常量
|
||||
var (
|
||||
// ErrUserNotFound 用户不存在错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
)
|
||||
|
||||
type GormUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.UserRepository = (*GormUserRepository)(nil)
|
||||
|
||||
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
|
||||
err := r.CreateEntity(ctx, &user)
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.SmartGetByID(ctx, id, &user)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, errors.New("用户不存在")
|
||||
}
|
||||
return entities.User{}, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) BatchGetByIDsWithEnterpriseInfo(ctx context.Context, ids []string) ([]*entities.User, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*entities.User{}, nil
|
||||
}
|
||||
|
||||
var users []*entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
r.GetLogger().Error("批量查询用户失败", zap.Error(err), zap.Strings("ids", ids))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
// 如果指定了排除的用户ID,则排除该用户的记录
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("users.id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
return r.UpdateEntity(ctx, &user)
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.SmartList(ctx, &users, options)
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
|
||||
var users []*entities.User
|
||||
err := r.GetDB(ctx).Where("user_type = ?", userType).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
|
||||
var users []*entities.User
|
||||
var total int64
|
||||
|
||||
// 构建查询条件,预加载企业信息
|
||||
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
|
||||
}
|
||||
if query.UserType != "" {
|
||||
db = db.Where("users.user_type = ?", query.UserType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
db = db.Where("users.active = ?", *query.IsActive)
|
||||
}
|
||||
if query.IsCertified != nil {
|
||||
db = db.Where("users.is_certified = ?", *query.IsCertified)
|
||||
}
|
||||
if query.CompanyName != "" {
|
||||
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("users.created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("users.created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序(默认按创建时间倒序)
|
||||
db = db.Order("users.created_at DESC")
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_login_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("password", newPassword).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ? AND password = ?", userID, password).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"login_count": gorm.Expr("login_count + 1"),
|
||||
"last_login_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数
|
||||
if err := db.Model(&entities.User{}).Where("active = ?", true).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStats 获取系统用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数(最近30天有登录)
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30)
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", thirtyDaysAgo).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStatsByDateRange 获取系统指定时间范围内的用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStatsByDateRange(ctx context.Context, startDate, endDate time.Time) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyUserStats 获取系统每日用户统计
|
||||
func (r *GormUserRepository) GetSystemDailyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE DATE(created_at) >= $1
|
||||
AND DATE(created_at) <= $2
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyUserStats 获取系统每月用户统计
|
||||
func (r *GormUserRepository) GetSystemMonthlyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE created_at >= $1
|
||||
AND created_at <= $2
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyCertificationStats 获取系统每日认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemDailyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(updated_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND DATE(updated_at) >= $1
|
||||
AND DATE(updated_at) <= $2
|
||||
GROUP BY DATE(updated_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyCertificationStats 获取系统每月认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemMonthlyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(updated_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND updated_at >= $1
|
||||
AND updated_at <= $2
|
||||
GROUP BY TO_CHAR(updated_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByCalls 按调用次数获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByCalls(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE(ac.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByConsumption 按消费金额获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByConsumption(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE(wt.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE_TRUNC('month', wt.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetRechargeRanking 获取充值排行(排除赠送,只统计成功状态)
|
||||
func (r *GormUserRepository) GetRechargeRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE(rr.created_at) = CURRENT_DATE
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE_TRUNC('month', rr.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
Reference in New Issue
Block a user