Skip to content

Commit

Permalink
Fix worker_init_fn to update DataPipe graph and move worker prefetc…
Browse files Browse the repository at this point in the history
…h to the end of Worker pipeline (#1100)

Summary:
Fixes #1083

### Changes

- Fix worker loop to update the DataPipe graph with `worker_init_fn`
  - Add corresponding Tests
- Since `worker_init_fn` function might attach new DataPipe to worker graph, guarantee `prefetch` in worker attached to the end of the pipeline in worker process

Pull Request resolved: #1100

Reviewed By: NivekT

Differential Revision: D44221216

Pulled By: ejguan

fbshipit-source-id: dfcbaa3e8a0d82df6dd21d308cc94f6d06bcf8f4
  • Loading branch information
ejguan committed Apr 20, 2023
1 parent aff31f9 commit 1957295
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
14 changes: 13 additions & 1 deletion test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,21 @@ def test_worker_fns(self, ctx):
)
dl = DataLoader2(dp, reading_service=rs)

# Test worker_reset_fn to set the same random seed across epochs
res1 = list(dl)
res2 = list(dl)

# Test worker_init_fn to set sharding
def _expand_fn(res):
result = []
for batch in res:
result.extend(batch)
return result

exp = list(range(100))
self.assertEqual(sorted(_expand_fn(res1)), exp)
self.assertEqual(sorted(_expand_fn(res2)), exp)

# Test worker_reset_fn to set the same random seed across epochs
self.assertEqual(res1, res2)

@mp_ctx_parametrize
Expand Down
17 changes: 16 additions & 1 deletion torchdata/dataloader2/communication/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, proc
r"""
Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes
with the protocol server in a non-blocking manner.
Args:
    source_datapipes: DataPipes being iterated in the dispatching process
    req_queues: Multiprocessing queues providing requests from the worker processes
    res_queues: Multiprocessing queues sending results to the worker processes
    process_name: The name of the process (used for logging and exception handling)
    call_on_process_init: Not supported by the dispatching process for now.
"""
assert call_on_process_init is None, "``MultipleDataPipesToQueuesLoop`` does not support call_on_process_init"
num_loops = len(source_datapipes)
Expand Down Expand Up @@ -104,12 +111,20 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, ca
r"""
Initialize with the given init function, set the appropriate pipe and protocol server type, and
create a loop with the protocol server.
Args:
source_datapipe: DataPipe being iterated in the worker process
req_queue: Multiprocessing queue providing requests from the main process
res_queue: Multiprocessing queue sending results to the main process
process_name: The name of process (used for logging and exception handling)
call_on_process_init: Callable that will be invoked at the time of worker process initialization.
    Users can provide it to modify the DataPipe graph in the worker process.
"""
# Extract Serialization Wrapper
source_datapipe = extract_wrapper(source_datapipe)

if call_on_process_init is not None:
call_on_process_init(source_datapipe)
source_datapipe = call_on_process_init(source_datapipe)

torch.set_num_threads(1)

Expand Down
6 changes: 2 additions & 4 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
if not self._mp:
# TODO(616): Warn and recommend usage of InProcessReadingService
worker_info = WorkerInfo(1, 0)
process_init_fn(datapipe, worker_info, self.worker_init_fn)
datapipe = process_init_fn(datapipe, worker_info, self.worker_init_fn)
self._end_datapipe = datapipe
return datapipe

Expand Down Expand Up @@ -260,9 +260,6 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
len(replicable_dps) == 1
), "MultiProcessingReadingService only supports single replicable branch currently"
replicable_dp = replicable_dps[0]

if self.worker_prefetch_cnt > 0:
replicable_dp = replicable_dp.prefetch(self.worker_prefetch_cnt)
replicable_dp = attach_wrapper(replicable_dp)

for worker_id in range(self.num_workers):
Expand All @@ -274,6 +271,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
process_init_fn,
worker_info=worker_info,
custom_init_fn=self.worker_init_fn,
worker_prefetch_cnt=self.worker_prefetch_cnt,
dispatching_req_queue=dispatching_req_queue,
dispatching_res_queue=dispatching_res_queue,
)
Expand Down
4 changes: 4 additions & 0 deletions torchdata/dataloader2/utils/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def process_init_fn(
datapipe: DataPipe,
worker_info: WorkerInfo,
custom_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
worker_prefetch_cnt: int = 0,
dispatching_req_queue: Optional[Queue] = None,
dispatching_res_queue: Optional[Queue] = None,
) -> DataPipe:
Expand Down Expand Up @@ -96,6 +97,9 @@ def process_init_fn(
datapipe = custom_init_fn(datapipe, worker_info)
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))

if worker_prefetch_cnt > 0:
datapipe = datapipe.prefetch(worker_prefetch_cnt)

return datapipe


Expand Down

0 comments on commit 1957295

Please sign in to comment.