changes to support DLIO with MuMMI and DYAD
hariharan-devarajan committed Mar 5, 2024
1 parent 549d49f commit f30904c
Showing 6 changed files with 311 additions and 26 deletions.
26 changes: 11 additions & 15 deletions pydyad/pydyad/hdf.py
@@ -7,35 +7,31 @@

 class DyadFile(h5py.File):

-    def __init__(self, *args, **kwargs, dyad_ctx=None, metadata_wrapper=None):
+    def __init__(self, fname, mode, file=None, dyad_ctx=None, metadata_wrapper=None):
         # According to H5PY, the first positional argument to File.__init__ is fname
-        self.fname = args[0]
+        self.fname = fname
         if not isinstance(self.fname, Path):
-            self.fname = Path(self.fname)
+            self.fname = Path(fname)
         self.fname = self.fname.expanduser().resolve()
-        self.mode = None
-        if "mode" in kwargs:
-            self.mode = kwargs["mode"]
-        elif len(args) > 1:
-            self.mode = args[1]
-        else:
-            raise NameError("'mode' argument not provided to pydyad.hdf.File constructor")
+        self.m = mode
         if dyad_ctx is None:
             raise NameError("'dyad_ctx' argument not provided to pydyad.hdf.File constructor")
         self.dyad_ctx = dyad_ctx
-        if self.mode in ("r", "rb", "rt"):
+        if self.m in ("r"):
             if (self.dyad_ctx.cons_path is not None and
                     self.dyad_ctx.cons_path in self.fname.parents):
                 if metadata_wrapper:
-                    self.dyad_ctx.consume_w_metadata(str(self.fname), meatadata_wrapper)
+                    self.dyad_ctx.consume_w_metadata(str(self.fname), metadata_wrapper)
                 else:
                     dyad_ctx.consume(str(self.fname))
-        super().__init__(*args, **kwargs)
+        if file:
+            super().__init__(file, mode)
+        else:
+            super().__init__(fname, mode)

     def close(self):
         super().close()
-        if self.mode in ("w", "wb", "wt"):
+        if self.m in ("w", "r+"):
             if (self.dyad_ctx.prod_path is not None and
                     self.dyad_ctx.prod_path in self.fname.parents):
                 self.dyad_ctx.produce(str(self.fname))
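
For orientation, here is a minimal produce/consume round trip through the updated wrapper. This is an editorial sketch, not part of the commit: it assumes a running Flux broker and DYAD_KVS_NAMESPACE set in the environment; the managed directory, file name, and dataset shape are made up, and the init() arguments mirror the worker_init call in the data loader below.

import os
import numpy as np
from pydyad import Dyad
from pydyad.bindings import DTLMode, DTLCommMode
from pydyad.hdf import DyadFile

dyad_io = Dyad()
managed_dir = os.getenv("DYAD_PATH", "/tmp/dyad")  # hypothetical default
dyad_io.init(debug=False, check=False, shared_storage=False, reinit=False,
             async_publish=True, fsync_write=False, key_depth=3,
             service_mux=1, key_bins=1024,
             kvs_namespace=os.getenv("DYAD_KVS_NAMESPACE"),
             prod_managed_path=managed_dir, cons_managed_path=managed_dir,
             dtl_mode=DTLMode.DYAD_DTL_FLUX_RPC,
             dtl_comm_mode=DTLCommMode.DYAD_COMM_RECV)

fname = os.path.join(managed_dir, "sample_0.h5")

# Producer side: "w" mode defers the actual I/O to h5py; close() then calls
# dyad_ctx.produce() because the file sits under prod_managed_path.
hf = DyadFile(fname, "w", dyad_ctx=dyad_io)
hf.create_dataset("records", data=np.zeros((8, 69528), dtype=np.uint8))
hf.close()

# Consumer side: "r" mode calls dyad_ctx.consume() (or consume_w_metadata()
# when a metadata wrapper is passed) before h5py opens the file.
hf = DyadFile(fname, "r", dyad_ctx=dyad_io)
sample = hf["records"][0]
hf.close()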
30 changes: 30 additions & 0 deletions tests/integration/dlio_benchmark/configs/workload/dyad_mummi.yaml
@@ -0,0 +1,30 @@
model: mummi

framework: pytorch

workflow:
generate_data: False
train: True

dataset:
data_folder: data/mummi/
format: hdf5
num_files_train: 600
num_samples_per_file: 8000
record_length: 69528
enable_chunking: True
chunk_size: 17799168

reader:
data_loader: pytorch
batch_size: 256
read_threads: 6
file_shuffle: seed
sample_shuffle: seed
multiprocessing_context: spawn
data_loader_classname: dyad_h5_torch_data_loader.DyadH5TorchDataLoader
data_loader_sampler: index

train:
epochs: 10
computation_time: 0
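
One relationship worth calling out in this config (an editorial note, not stated in the commit): the HDF5 chunk size appears to be exactly one batch of records, i.e. reader.batch_size times dataset.record_length. A quick check:

record_length = 69528   # dataset.record_length, bytes per sample
batch_size = 256        # reader.batch_size
chunk_size = 17799168   # dataset.chunk_size

# one chunk holds one full batch of samples
assert batch_size * record_length == chunk_size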
@@ -0,0 +1,30 @@
model: mummi

framework: pytorch

workflow:
generate_data: False
train: True

dataset:
data_folder: data/mummi/
format: hdf5
num_files_train: 1
num_samples_per_file: 100
record_length: 69528
enable_chunking: True
chunk_size: 17799168

reader:
data_loader: pytorch
batch_size: 1
read_threads: 2
file_shuffle: seed
sample_shuffle: seed
multiprocessing_context: spawn
data_loader_classname: dyad_h5_torch_data_loader.DyadH5TorchDataLoader
data_loader_sampler: index

train:
epochs: 10
computation_time: .133
@@ -28,9 +28,3 @@ reader:
 train:
   epochs: 10
   computation_time: 0
-
-checkpoint:
-  checkpoint_folder: checkpoints/unet3d
-  checkpoint_after_epoch: 5
-  epochs_between_checkpoints: 2
-  model_size: 499153191
234 changes: 234 additions & 0 deletions tests/integration/dlio_benchmark/dyad_h5_torch_data_loader.py
@@ -0,0 +1,234 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from time import time
import logging
import math
import pickle
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import Shuffle, DatasetType, DataLoaderType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, DLIOMPI
from dlio_benchmark.utils.config import ConfigArguments
from dlio_profiler.logger import fn_interceptor as Profile

from pydyad import Dyad
from pydyad.hdf import DyadFile
from pydyad.bindings import DTLMode, DTLCommMode
import numpy as np
import flux
import os
import h5py
import fcntl
dlp = Profile(MODULE_DATA_LOADER)


class DYADH5TorchDataset(Dataset):
    """
    Currently, we only support loading one sample per file
    TODO: support multiple samples per file
    """
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch, num_samples, num_workers, batch_size):
        self.format_type = format_type
        self.dataset_type = dataset_type
        self.epoch_number = epoch
        self.num_samples = num_samples
        self.reader = None
        self.num_images_read = 0
        self.batch_size = batch_size
        args = ConfigArguments.get_instance()
        self.serial_args = pickle.dumps(args)
        self.dlp_logger = None
        if num_workers == 0:
            self.worker_init(-1)

    @dlp.log
    def worker_init(self, worker_id):
        pickle.loads(self.serial_args)
        self._args = ConfigArguments.get_instance()
        self._args.configure_dlio_logging(is_child=True)
        self.dlp_logger = self._args.configure_dlio_profiler(is_child=True, use_pid=True)
        logging.debug(f"{utcnow()} worker initialized {worker_id} with format {self.format_type}")
        self.reader = ReaderFactory.get_reader(type=self.format_type,
                                               dataset_type=self.dataset_type,
                                               thread_index=worker_id,
                                               epoch_number=self.epoch_number)
        self.dyad_io = Dyad()
        is_local = os.getenv("DYAD_LOCAL_TEST", "0") == "1"
        self.broker_per_node = int(os.getenv("BROKERS_PER_NODE", "1"))

        self.f = flux.Flux()
        self.broker_rank = self.f.get_rank()
        if is_local:
            # Local test mode: each broker gets its own DYAD-managed subdirectory.
            self.dyad_managed_directory = os.path.join(os.getenv("DYAD_PATH", ""), str(self.f.get_rank()))
        else:
            self.dyad_managed_directory = os.getenv("DYAD_PATH", "")
        # Map this broker's rank to a node index so local vs. remote reads can be told apart.
        self.my_node_index = int(self.broker_rank * 1.0 / self.broker_per_node)
        dtl_str = os.getenv("DYAD_DTL_MODE", "FLUX_RPC")
        mode = DTLMode.DYAD_DTL_FLUX_RPC
        namespace = os.getenv("DYAD_KVS_NAMESPACE")
        logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} init dyad {self.dyad_managed_directory} {dtl_str} {namespace}")
        if dtl_str == "UCX":
            mode = DTLMode.DYAD_DTL_UCX
        self.dyad_io.init(debug=self._args.debug, check=False, shared_storage=False, reinit=False,
                          async_publish=True, fsync_write=False, key_depth=3,
                          service_mux=self.broker_per_node,
                          key_bins=1024, kvs_namespace=os.getenv("DYAD_KVS_NAMESPACE"),
                          prod_managed_path=self.dyad_managed_directory, cons_managed_path=self.dyad_managed_directory,
                          dtl_mode=mode, dtl_comm_mode=DTLCommMode.DYAD_COMM_RECV)

    def __del__(self):
        if self.dlp_logger:
            self.dlp_logger.finalize()

    @dlp.log
    def __len__(self):
        return self.num_samples

    @dlp.log
    def __getitem__(self, image_idx):
        logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} image")
        self.num_images_read += 1
        step = int(math.ceil(self.num_images_read / self.batch_size))
        filename, sample_index = self._args.global_index_map[image_idx]
        is_present = False
        file_obj = None
        base_fname = filename
        dlp.update(args={"fname": filename})
        dlp.update(args={"image_idx": image_idx})
        if self.dyad_managed_directory != "":
            logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading metadata")
            base_fname = os.path.join(self.dyad_managed_directory, os.path.basename(filename))
            file_obj = self.dyad_io.get_metadata(fname=base_fname, should_wait=False, raw=True)
            logging.debug(f"Using managed directory {self.dyad_managed_directory} {base_fname} {file_obj}")
            is_present = True
        if file_obj:
            # DYAD owns the file: read it through DYAD, locally or from a remote broker.
            access_mode = "remote"
            file_node_index = int(file_obj.contents.owner_rank * 1.0 / self.broker_per_node)
            if self.my_node_index == file_node_index:
                access_mode = "local"
            dlp.update(args={"owner_rank": str(file_obj.contents.owner_rank)})
            dlp.update(args={"my_broker": str(self.broker_rank)})
            dlp.update(args={"mode": "dyad"})
            dlp.update(args={"access": access_mode})
            logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample from {access_mode} dyad {file_obj.contents.owner_rank}")
            logging.debug(f"Reading from managed directory {base_fname}")
            hf = DyadFile(base_fname, "r", dyad_ctx=self.dyad_io, metadata_wrapper=file_obj)
            try:
                data = hf["records"][sample_index]
            except:
                logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} got weird {image_idx} sample from {access_mode} dyad {file_obj.contents.owner_rank}")
                data = self._args.resized_image
            hf.close()
            logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} read {image_idx} sample from {access_mode} dyad {file_obj.contents.owner_rank}")
            self.dyad_io.free_metadata(file_obj)
        else:
            # Fall back to the PFS: stage the file into the DYAD-managed directory
            # under an exclusive lock, then read the sample from the staged copy.
            dlp.update(args={"mode": "pfs"})
            dlp.update(args={"access": "remote"})
            logging.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample from pfs {base_fname}")
            logging.debug(f"Reading from pfs {base_fname}")
            dyad_f = open(base_fname, "wb")
            fcntl.lockf(dyad_f, fcntl.LOCK_EX)
            dyad_f.seek(0, 2)
            size = dyad_f.tell()
            if size == 0:
                pfs_f = open(filename, "rb")
                data = pfs_f.read()
                dyad_f.write(data)
                pfs_f.close()
            fcntl.lockf(dyad_f, fcntl.LOCK_UN)
            dyad_f.close()
            hf = DyadFile(base_fname, "r+", dyad_ctx=self.dyad_io)
            data = hf["records"][sample_index]
            hf.close()
            logging.debug(f"Read from pfs {base_fname}")
        dlp.update(step=step)
        dlp.update(image_size=data.nbytes)
        return data


class DyadH5TorchDataLoader(BaseDataLoader):
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.PYTORCH)

    @dlp.log
    def read(self):
        do_shuffle = True if self._args.sample_shuffle != Shuffle.OFF else False
        dataset = DYADH5TorchDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples, self._args.read_threads, self.batch_size)
        if do_shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
        if self._args.read_threads >= 1:
            prefetch_factor = math.ceil(self._args.prefetch_size / self._args.read_threads)
        else:
            prefetch_factor = self._args.prefetch_size
        if prefetch_factor > 0:
            if self._args.my_rank == 0:
                logging.debug(
                    f"{utcnow()} Prefetch size is {self._args.prefetch_size}; prefetch factor of {prefetch_factor} will be set for the Torch DataLoader.")
        else:
            prefetch_factor = 2
            if self._args.my_rank == 0:
                logging.debug(
                    f"{utcnow()} Prefetch size is 0; the default prefetch factor of 2 will be set for the Torch DataLoader.")
        logging.debug(f"{utcnow()} Setup dataloader with {self._args.read_threads} workers {torch.__version__}")
        if self._args.read_threads == 0:
            kwargs = {}
        else:
            kwargs = {'multiprocessing_context': self._args.multiprocessing_context,
                      'prefetch_factor': prefetch_factor}
            if torch.__version__ != '1.3.1':
                kwargs['persistent_workers'] = True
        if torch.__version__ == '1.3.1':
            if 'prefetch_factor' in kwargs:
                del kwargs['prefetch_factor']
            self._dataset = DataLoader(dataset,
                                       batch_size=self.batch_size,
                                       sampler=sampler,
                                       num_workers=self._args.read_threads,
                                       pin_memory=True,
                                       drop_last=True,
                                       worker_init_fn=dataset.worker_init,
                                       **kwargs)
        else:
            self._dataset = DataLoader(dataset,
                                       batch_size=self.batch_size,
                                       sampler=sampler,
                                       num_workers=self._args.read_threads,
                                       pin_memory=True,
                                       drop_last=True,
                                       worker_init_fn=dataset.worker_init,
                                       **kwargs)  # 2 is the default value
        logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * self.batch_size} files")

        # self._dataset.sampler.set_epoch(epoch_number)

    @dlp.log
    def next(self):
        super().next()
        total = self._args.training_steps if self.dataset_type is DatasetType.TRAIN else self._args.eval_steps
        logging.debug(f"{utcnow()} Rank {self._args.my_rank} should read {total} batches")
        for batch in self._dataset:
            yield batch

    @dlp.log
    def finalize(self):
        pass
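
The local-versus-remote classification in __getitem__ above comes down to mapping broker ranks to node indices via BROKERS_PER_NODE. A standalone sketch of that arithmetic (the broker count and ranks here are made up for illustration):

brokers_per_node = 2  # hypothetical; BROKERS_PER_NODE in setup-env.sh below

def access_mode(my_broker_rank, owner_rank):
    # Same arithmetic as worker_init / __getitem__ above.
    my_node = int(my_broker_rank * 1.0 / brokers_per_node)
    owner_node = int(owner_rank * 1.0 / brokers_per_node)
    return "local" if my_node == owner_node else "remote"

assert access_mode(4, 5) == "local"   # ranks 4 and 5 both map to node 2
assert access_mode(4, 6) == "remote"  # rank 6 maps to node 3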
11 changes: 6 additions & 5 deletions tests/integration/dlio_benchmark/setup-env.sh
@@ -4,10 +4,10 @@ module load python/3.9.12
 module load openmpi/4.1.2

 # Configurations
-export DLIO_WORKLOAD=dyad_resnet50 #dyad_unet3d_large # unet3d_base dyad_unet3d dyad_unet3d_small resnet50_base dyad_resnet50 unet3d_base_large
-export NUM_NODES=32
+export DLIO_WORKLOAD=dyad_mummi #dyad_unet3d_large # unet3d_base dyad_unet3d dyad_unet3d_small resnet50_base dyad_resnet50 unet3d_base_large mummi_base dyad_mummi
+export NUM_NODES=4
 export PPN=8
-export QUEUE=pbatch
+export QUEUE=pdebug
 export TIME=$((60))
 export BROKERS_PER_NODE=1
 export GENERATE_DATA="0"
@@ -21,11 +21,11 @@ export GITHUB_WORKSPACE=/usr/workspace/haridev/dyad
 export SPACK_DIR=/usr/workspace/haridev/spack
 export SPACK_ENV=/usr/workspace/haridev/dyad/env/spack
 export PYTHON_ENV=/usr/workspace/haridev/dyad/env/python
-export DLIO_DATA_DIR=/p/lustre1/iopp/dyad/dlio_benchmark/dyad_resnet50 # dyad_resnet50 dyad_unet3d_basic
+export DLIO_DATA_DIR=/p/lustre1/iopp/dyad/dlio_benchmark/dyad_mummi # dyad_resnet50 dyad_unet3d_basic
 #export DLIO_DATA_DIR=/p/lustre2/haridev/dyad/dlio_benchmark/dyad_unet3d_basic # dyad_resnet50

 # DLIO Profiler Configurations
-export DLIO_PROFILER_ENABLE=1
+export DLIO_PROFILER_ENABLE=0
 export DLIO_PROFILER_INC_METADATA=1
 export DLIO_PROFILER_DATA_DIR=${DLIO_DATA_DIR}:${DYAD_PATH}
 export DLIO_PROFILER_LOG_FILE=/usr/workspace/haridev/dyad/tests/integration/dlio_benchmark/profiler/dyad
@@ -35,6 +35,7 @@ export DLIO_PROFILER_LOG_LEVEL=ERROR

 export DLIO_PROFILER_BIND_SIGNALS=0
 export MV2_BCAST_HWLOC_TOPOLOGY=0
+export HDF5_USE_FILE_LOCKING=0

 #mkdir -p ${DYAD_PATH}
 mkdir -p ${DLIO_PROFILER_LOG_FILE}
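
Two of these settings feed directly into the data loader above: BROKERS_PER_NODE drives the broker-to-node mapping, and DYAD_PATH is the managed directory (suffixed with the broker rank when DYAD_LOCAL_TEST=1); disabling HDF5_USE_FILE_LOCKING also appears consistent with the loader doing its own fcntl locking on staged copies. A condensed sketch of the directory resolution, with illustrative values (the fallback path and rank below are assumptions, not from this commit):

import os

# Mirrors the environment handling in DYADH5TorchDataset.worker_init above.
dyad_path = os.getenv("DYAD_PATH", "/l/ssd/dyad")  # hypothetical default
broker_rank = 0  # flux.Flux().get_rank() in the loader
if os.getenv("DYAD_LOCAL_TEST", "0") == "1":
    managed_dir = os.path.join(dyad_path, str(broker_rank))  # per-broker subdir
else:
    managed_dir = dyad_path  # one shared managed directory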
