Skip FullSync operation when world_size == 1 (#1065)
Summary: Pull Request resolved: #1065

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D43751382

Pulled By: NivekT

fbshipit-source-id: cd85e6dc360494289b82facb0cee34502a178c27
NivekT authored and ejguan committed Apr 20, 2023
1 parent 1957295 commit e78c1a9
Showing 2 changed files with 15 additions and 10 deletions.
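
For orientation, here is a minimal sketch (not part of this commit) of the behavior the change enables: once `dist.get_world_size()` returns 1, `fullsync()` simply passes elements through instead of running cross-rank synchronization. The single-rank gloo setup and the MASTER_ADDR/MASTER_PORT values below are assumptions for illustration.

import os

import torch.distributed as dist
from torchdata.datapipes.iter import IterableWrapper

# Assumed single-process setup; any free port works.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

dp = IterableWrapper(range(10)).fullsync()
assert list(dp) == list(range(10))  # with world_size == 1, elements pass through untouched

dist.destroy_process_group()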
7 changes: 5 additions & 2 deletions test/test_distributed.py
@@ -42,6 +42,8 @@
 if dist.is_nccl_available() and torch.cuda.device_count() > 0:
     _backends.append("nccl")
 
+
+world_size_parametrize = parametrize("world_size", [1, DEFAULT_WORLD_SIZE])
 backend_parametrize = parametrize("backend", _backends)
 

@@ -149,9 +151,10 @@ def _test_fullsync(rank, world_size, backend, q):
 
         _finalize_distributed_queue(rank, q)
 
+    @world_size_parametrize
     @backend_parametrize
-    def test_fullsync(self, backend) -> None:
-        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
+    def test_fullsync(self, world_size, backend) -> None:
+        world_size = world_size if backend != "nccl" else torch.cuda.device_count()
         launch_distributed_training(backend, world_size, fn=DistributedTest._test_fullsync)
 
     @staticmethod
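
`launch_distributed_training` itself is not part of this diff; the sketch below is an assumed, roughly equivalent way to spawn a gloo run for a given world size with standard torch primitives, just to show what the new `world_size` parametrization exercises.

import os

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int) -> None:
    # Assumed rendezvous settings; each spawned process joins the same gloo group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... iterate a FullSync datapipe here, as DistributedTest._test_fullsync does ...
    dist.destroy_process_group()

def launch(world_size: int) -> None:
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)

# launch(1) hits the new pass-through branch; launch(DEFAULT_WORLD_SIZE) hits the synchronized path.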
18 changes: 10 additions & 8 deletions torchdata/datapipes/iter/util/distributed.py
@@ -171,13 +171,17 @@ def _callback_fn(self, exp: Expected) -> None:
 
     def __iter__(self) -> Iterator[T_co]:
         assert self._executor is None
 
         if not (dist.is_available() and dist.is_initialized()):
-            raise RuntimeError("Torch Distributed is required to be initialized")
+            raise RuntimeError("Torch Distributed is required to be initialized to use `FullSync`.")
 
         if self._process_group is None:
             self._process_group = dist.new_group(backend="gloo")
         self._world_size = dist.get_world_size()
 
+        if self._world_size == 1:  # The below functionalities are not needed if `_world_size == 1`
+            yield from self.datapipe
+            return
+
         self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
         while True:
             with self._cv:
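
The new branch leans on a generator detail: a bare `return` after `yield from` simply ends iteration, so none of the executor setup below it runs for a single rank. A standalone sketch of the same pattern, with hypothetical names:

from typing import Iterable, Iterator

def maybe_passthrough(source: Iterable[int], skip_sync: bool) -> Iterator[int]:
    if skip_sync:
        yield from source  # hand elements through untouched
        return             # ends the generator early; nothing after this executes
    for x in source:
        yield x * 10       # stand-in for the synchronized, executor-backed path

assert list(maybe_passthrough([1, 2, 3], skip_sync=True)) == [1, 2, 3]
assert list(maybe_passthrough([1, 2, 3], skip_sync=False)) == [10, 20, 30]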
@@ -231,11 +235,9 @@ def __setstate__(self, state):
         self._done_callback = False
 
     def pause(self):
-        raise RuntimeError("`pause` is not supported for FullSync at the moment.")
-        # if self._executor is not None:
-        #     self._executor.shutdown()
-        #     self._executor = None
+        if self._world_size > 1 and self._executor is not None:
+            raise RuntimeError("`pause` is not supported for FullSync at the moment.")
 
     def resume(self):
-        raise RuntimeError("`resume` is not supported for FullSync at the moment.")
-        # self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
+        if self._world_size > 1 and self._executor is not None:
+            raise RuntimeError("`resume` is not supported for FullSync at the moment.")
