package repositories

import (
	"context"
	"errors"
	"fmt"
	"time"

	"go.uber.org/zap"
	"gorm.io/gorm"

	"tyapi-server/internal/domains/user/entities"
	"tyapi-server/internal/shared/interfaces"
)

// Sentinel errors returned by this package.
var (
	// ErrUserNotFound is returned when no user row matches the lookup.
	ErrUserNotFound = errors.New("用户不存在")
)

// userCacheTTL is how long a user loaded from the database is kept in the cache.
const userCacheTTL = 10 * time.Minute

// userIDCacheKey builds the cache key under which a user is stored by ID.
func userIDCacheKey(id string) string {
	return fmt.Sprintf("user:id:%s", id)
}

// userPhoneCacheKey builds the cache key under which a user is stored by phone number.
func userPhoneCacheKey(phone string) string {
	return fmt.Sprintf("user:phone:%s", phone)
}

// UserRepository is the GORM-backed user repository with a read-through cache
// in front of the single-user lookups.
type UserRepository struct {
	db     *gorm.DB
	cache  interfaces.CacheService
	logger *zap.Logger
}

// NewUserRepository builds a UserRepository from its dependencies.
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *UserRepository {
	return &UserRepository{
		db:     db,
		cache:  cache,
		logger: logger,
	}
}

// Create inserts user into the database.
//
// On success the cache entry keyed by the user's phone number is deleted and
// the new user's ID is logged. On failure the database error is logged and
// returned unchanged.
func (r *UserRepository) Create(ctx context.Context, user *entities.User) error {
	if err := r.db.WithContext(ctx).Create(user).Error; err != nil {
		r.logger.Error("创建用户失败", zap.Error(err))
		return err
	}

	// Drop any entry stored under this phone number.
	r.deleteCacheByPhone(ctx, user.Phone)

	r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
	return nil
}

// GetByID returns the user with the given ID.
//
// The cache is consulted first; on any cache error the database is queried
// and the result is written back to the cache with userCacheTTL. It returns
// ErrUserNotFound when the database has no matching row, and the underlying
// database error (after logging it) for any other failure.
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entities.User, error) {
	cacheKey := userIDCacheKey(id)

	var cached entities.User
	if err := r.cache.Get(ctx, cacheKey, &cached); err == nil {
		return &cached, nil
	}

	// Query into a fresh struct rather than the one handed to the cache: GORM
	// uses a non-zero primary key on the destination as an extra condition, so
	// a destination that a failed cache read had partially filled (presumably
	// possible, depending on the cache implementation) must not be reused.
	var user entities.User
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		r.logger.Error("根据ID查询用户失败", zap.Error(err))
		return nil, err
	}

	// Best-effort write-back; the outcome does not affect the lookup result.
	r.cache.Set(ctx, cacheKey, &user, userCacheTTL)

	return &user, nil
}

// FindByPhone returns the user with the given phone number.
//
// The cache is consulted first; on any cache error the database is queried
// and the result is written back to the cache with userCacheTTL. It returns
// ErrUserNotFound when the database has no matching row, and the underlying
// database error (after logging it) for any other failure.
func (r *UserRepository) FindByPhone(ctx context.Context, phone string) (*entities.User, error) {
	cacheKey := userPhoneCacheKey(phone)

	var cached entities.User
	if err := r.cache.Get(ctx, cacheKey, &cached); err == nil {
		return &cached, nil
	}

	// Fresh destination for the same reason as in GetByID: a leftover primary
	// key from a failed cache read would be ANDed into the WHERE clause.
	var user entities.User
	if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		r.logger.Error("根据手机号查询用户失败", zap.Error(err))
		return nil, err
	}

	// Best-effort write-back; the outcome does not affect the lookup result.
	r.cache.Set(ctx, cacheKey, &user, userCacheTTL)

	return &user, nil
}

// Update persists user with GORM's Save and then deletes the cache entries
// keyed by the user's ID and by the user's current phone number.
//
// NOTE(review): if the phone number was changed, the entry keyed by the
// previous phone number is not deleted here — confirm whether callers rely on
// it expiring on its own.
func (r *UserRepository) Update(ctx context.Context, user *entities.User) error {
	if err := r.db.WithContext(ctx).Save(user).Error; err != nil {
		r.logger.Error("更新用户失败", zap.Error(err))
		return err
	}

	r.deleteCacheByID(ctx, user.ID)
	r.deleteCacheByPhone(ctx, user.Phone)

	r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
	return nil
}

// Delete removes the user with the given ID.
//
// The user is loaded first (through GetByID, so possibly from the cache)
// because its phone number is needed to delete the phone-keyed cache entry.
// Any error from that lookup, including ErrUserNotFound, is returned and
// nothing is deleted.
func (r *UserRepository) Delete(ctx context.Context, id string) error {
	user, err := r.GetByID(ctx, id)
	if err != nil {
		return err
	}

	if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
		r.logger.Error("删除用户失败", zap.Error(err))
		return err
	}

	r.deleteCacheByID(ctx, id)
	r.deleteCacheByPhone(ctx, user.Phone)

	r.logger.Info("用户删除成功", zap.String("user_id", id))
	return nil
}

// List returns one page of users selected by offset and limit. It always
// queries the database; the cache is not involved.
//
// NOTE(review): the query has no ORDER BY, so the page contents depend on the
// database's default row order — confirm that callers do not need stable
// pagination.
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*entities.User, error) {
	var users []*entities.User
	if err := r.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&users).Error; err != nil {
		r.logger.Error("查询用户列表失败", zap.Error(err))
		return nil, err
	}
	return users, nil
}

// Count returns the number of user rows as reported by the database.
func (r *UserRepository) Count(ctx context.Context) (int64, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&entities.User{}).Count(&count).Error; err != nil {
		r.logger.Error("统计用户数量失败", zap.Error(err))
		return 0, err
	}
	return count, nil
}

// ExistsByPhone reports whether at least one user row has the given phone
// number. It always queries the database; the cache is not involved.
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
	var count int64
	if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("phone = ?", phone).Count(&count).Error; err != nil {
		r.logger.Error("检查手机号是否存在失败", zap.Error(err))
		return false, err
	}
	return count > 0, nil
}

// Private helpers.

// deleteCacheByID deletes the cache entry keyed by the user ID. A failure is
// logged as a warning and otherwise ignored, so it never fails the caller.
func (r *UserRepository) deleteCacheByID(ctx context.Context, id string) {
	cacheKey := userIDCacheKey(id)
	if err := r.cache.Delete(ctx, cacheKey); err != nil {
		r.logger.Warn("删除用户ID缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
	}
}

// deleteCacheByPhone deletes the cache entry keyed by the phone number. A
// failure is logged as a warning and otherwise ignored, so it never fails the
// caller.
func (r *UserRepository) deleteCacheByPhone(ctx context.Context, phone string) {
	cacheKey := userPhoneCacheKey(phone)
	if err := r.cache.Delete(ctx, cacheKey); err != nil {
		r.logger.Warn("删除用户手机号缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
	}
}