From 9b5daea0777bb94ea95bab76c1551d3b865de0f7 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Tue, 14 Nov 2023 15:05:53 +0100 Subject: [PATCH] Enable the easy download of the deployment.tar.gz (#379) --- src/sparsezoo/model/model.py | 34 ++++++++++++++---- src/sparsezoo/objects/directory.py | 15 +++++++- tests/sparsezoo/model/test_model.py | 56 +++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/src/sparsezoo/model/model.py b/src/sparsezoo/model/model.py index 48dce7ec..938aa1ae 100644 --- a/src/sparsezoo/model/model.py +++ b/src/sparsezoo/model/model.py @@ -31,7 +31,6 @@ save_outputs_to_tar, ) from sparsezoo.objects import ( - AliasedSelectDirectory, Directory, File, NumpyDirectory, @@ -138,12 +137,26 @@ def __init__(self, source: str, download_path: Optional[str] = None): files, directory_class=Directory, display_name="sample-labels" ) - self.deployment: AliasedSelectDirectory = self._directory_from_files( + self.deployment: SelectDirectory = self._directory_from_files( files, - directory_class=AliasedSelectDirectory, + directory_class=SelectDirectory, display_name="deployment", - download_alias="deployment.tar.gz", stub_params=self.stub_params, + allow_multiple_outputs=True, + ) + + if isinstance(self.deployment, list): + # if there are multiple deployment directories + # (this may happen due to the presence of both + # - deployment directory + # - deployment.tar.gz file + # we need to choose one (they are identical) + self.deployment = self.deployment[0] + + self.deployment_tar: SelectDirectory = self._directory_from_files( + files, + directory_class=SelectDirectory, + display_name="deployment.tar.gz", ) self.onnx_folder: Directory = self._directory_from_files( @@ -196,6 +209,7 @@ def __init__(self, source: str, download_path: Optional[str] = None): self._files_dictionary = { "training": self.training, "deployment": self.deployment, + "deployment.tar.gz": self.deployment_tar, "onnx_folder": self.onnx_folder, "logs": self.logs, "sample_originals": self.sample_originals, @@ -233,9 +247,9 @@ def deployment_directory_path(self) -> str: deployment directory if compressed """ # trigger initial download if not downloaded - self.deployment.path - if self.deployment.is_archive: - self.deployment.unzip() + self.deployment_tar.path + if self.deployment_tar.is_archive: + self.deployment_tar.unzip() return self.deployment.path @@ -310,6 +324,12 @@ def download( else: downloads = [] for key, file in self._files_dictionary.items(): + if key == "deployment": + # skip the download of the deployment directory + # since identical files will be downloaded + # in the deployment_tar + _LOGGER.debug(f"Intentionally skipping downloading the file {key}") + continue if file is not None: # save all the files to a temporary directory downloads.append(self._download(file, download_path)) diff --git a/src/sparsezoo/objects/directory.py b/src/sparsezoo/objects/directory.py index c03f1466..f0338a88 100644 --- a/src/sparsezoo/objects/directory.py +++ b/src/sparsezoo/objects/directory.py @@ -298,9 +298,22 @@ def unzip(self, extract_directory: Optional[str] = None, force: bool = False): member.name = os.path.basename(member.name) tar.extract(member=member, path=path) files.append( - File(name=member.name, path=os.path.join(path, member.name)) + File( + name=member.name, + path=os.path.join(path, member.name), + parent_directory=path, + ) ) tar.close() + # if path already exists, then the tar archive has already been unzipped + # and we can just use the files in the directory + elif os.path.exists(path): + for file in os.listdir(path): + files.append( + File( + name=file, path=os.path.join(path, file), parent_directory=path + ) + ) self.name = name self.files = files diff --git a/tests/sparsezoo/model/test_model.py b/tests/sparsezoo/model/test_model.py index 89b44ff7..78c47a44 100644 --- a/tests/sparsezoo/model/test_model.py +++ b/tests/sparsezoo/model/test_model.py @@ -23,10 +23,12 @@ import pytest from sparsezoo import Model +from sparsezoo.objects.directories import SelectDirectory files_ic = { "training", + "deployment.tar.gz", "deployment", "logs", "onnx", @@ -182,6 +184,10 @@ def setup(self, stub, clone_sample_outputs, expected_files): temp_dir = tempfile.TemporaryDirectory(dir="/tmp") model = Model(stub, temp_dir.name) model.download() + # since downloading the `deployment` file is + # disabled by default, we need to do it + # explicitly + model.deployment.download() self._add_mock_files(temp_dir.name, clone_sample_outputs=clone_sample_outputs) model = Model(temp_dir.name) @@ -329,6 +335,56 @@ def test_model_gz_extraction_from_local_files(stub: str): shutil.rmtree(temp_dir.name) +@pytest.mark.parametrize( + "stub", + [ + "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/" + "imagenet/pruned-moderate", + ], +) +def test_model_deployment_directory(stub): + temp_dir = tempfile.TemporaryDirectory(dir="/tmp") + expected_deployment_files = ["model.onnx"] + + model = Model(stub, temp_dir.name) + assert model.deployment_tar.is_archive + # download and extract deployment tar + deployment_dir_path = model.deployment_directory_path + + # deployment and deployment_tar should be point to the same files + assert deployment_dir_path == model.deployment_tar.path == model.deployment.path + # make sure that the model contains expected files + assert set(os.listdir(temp_dir.name)) == {"deployment.tar.gz", "deployment"} + assert ( + os.listdir(os.path.join(temp_dir.name, "deployment")) + == expected_deployment_files + ) + + assert isinstance(model.deployment, SelectDirectory) + # TODO: this should be 1. However, the API is returning for `deployment` file type + # both `model.onnx` and `deployment/model.onnx`. + # This should probably be fixed on the API side + assert ( + len(model.deployment.files) == 2 + ) # should be == len(expected_deployment_files) + + assert isinstance(model.deployment_tar, SelectDirectory) + assert len(model.deployment_tar.files) == len(expected_deployment_files) + assert not model.deployment_tar.is_archive + + # test recreating the model from the local files + model = Model(temp_dir.name) + + assert isinstance(model.deployment, SelectDirectory) + assert len(model.deployment.files) == len(expected_deployment_files) + + assert isinstance(model.deployment_tar, SelectDirectory) + assert len(model.deployment_tar.files) == len(expected_deployment_files) + assert not model.deployment_tar.is_archive + + shutil.rmtree(temp_dir.name) + + def _extraction_test_helper(model: Model): # download and extract model.onnx.tar.gz # path should point to extracted model.onnx file