diff --git a/test/test_distributed.py b/test/test_distributed.py
index 9d70d1a86..176e0ee84 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -42,6 +42,8 @@
 if dist.is_nccl_available() and torch.cuda.device_count() > 0:
     _backends.append("nccl")
 
+
+world_size_parametrize = parametrize("world_size", [1, DEFAULT_WORLD_SIZE])
 backend_parametrize = parametrize("backend", _backends)
 
 
@@ -149,9 +151,10 @@ def _test_fullsync(rank, world_size, backend, q):
 
         _finalize_distributed_queue(rank, q)
 
+    @world_size_parametrize
     @backend_parametrize
-    def test_fullsync(self, backend) -> None:
-        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
+    def test_fullsync(self, world_size, backend) -> None:
+        world_size = world_size if backend != "nccl" else torch.cuda.device_count()
         launch_distributed_training(backend, world_size, fn=DistributedTest._test_fullsync)
 
     @staticmethod
diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py
index 2e6fe813c..bc90481e8 100644
--- a/torchdata/datapipes/iter/util/distributed.py
+++ b/torchdata/datapipes/iter/util/distributed.py
@@ -171,13 +171,17 @@ def _callback_fn(self, exp: Expected) -> None:
 
     def __iter__(self) -> Iterator[T_co]:
         assert self._executor is None
-
         if not (dist.is_available() and dist.is_initialized()):
-            raise RuntimeError("Torch Distributed is required to be initialized")
+            raise RuntimeError("Torch Distributed is required to be initialized to use `FullSync`.")
+
         if self._process_group is None:
             self._process_group = dist.new_group(backend="gloo")
         self._world_size = dist.get_world_size()
 
+        if self._world_size == 1:  # The below functionalities are not needed if `_world_size == 1`
+            yield from self.datapipe
+            return
+
         self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
         while True:
             with self._cv:
@@ -231,11 +235,9 @@ def __setstate__(self, state):
         self._done_callback = False
 
     def pause(self):
-        raise RuntimeError("`pause` is not supported for FullSync at the moment.")
-        # if self._executor is not None:
-        #     self._executor.shutdown()
-        #     self._executor = None
+        if self._world_size > 1 and self._executor is not None:
+            raise RuntimeError("`pause` is not supported for FullSync at the moment.")
 
     def resume(self):
-        raise RuntimeError("`resume` is not supported for FullSync at the moment.")
-        # self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
+        if self._world_size > 1 and self._executor is not None:
+            raise RuntimeError("`resume` is not supported for FullSync at the moment.")
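
For context (not part of the diff): a minimal single-process sketch of the new fast path. The gloo setup, the port choice, and the use of IterableWrapper are illustrative assumptions; the point is that with world_size == 1, fullsync() now yields directly from the wrapped datapipe instead of starting the prefetch executor.

# Illustrative sketch (assumption: not part of this PR) of the world_size == 1 fast path.
import os

import torch.distributed as dist
from torchdata.datapipes.iter import IterableWrapper

if __name__ == "__main__":
    # FullSync still requires an initialized distributed backend, even with a single rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    dp = IterableWrapper(range(10)).fullsync()
    # With this change, elements pass straight through; no prefetch executor is created.
    assert list(dp) == list(range(10))

    dist.destroy_process_group()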