Skip to content

Commit

Permalink
Enable the easy download of the deployment.tar.gz (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored and bfineran committed Nov 16, 2023
1 parent 178cbf5 commit 9b5daea
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 8 deletions.
34 changes: 27 additions & 7 deletions src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
save_outputs_to_tar,
)
from sparsezoo.objects import (
AliasedSelectDirectory,
Directory,
File,
NumpyDirectory,
Expand Down Expand Up @@ -138,12 +137,26 @@ def __init__(self, source: str, download_path: Optional[str] = None):
files, directory_class=Directory, display_name="sample-labels"
)

self.deployment: AliasedSelectDirectory = self._directory_from_files(
self.deployment: SelectDirectory = self._directory_from_files(
files,
directory_class=AliasedSelectDirectory,
directory_class=SelectDirectory,
display_name="deployment",
download_alias="deployment.tar.gz",
stub_params=self.stub_params,
allow_multiple_outputs=True,
)

if isinstance(self.deployment, list):
# if there are multiple deployment directories
# (this may happen due to the presence of both
# - deployment directory
# - deployment.tar.gz file
# we need to choose one (they are identical)
self.deployment = self.deployment[0]

self.deployment_tar: SelectDirectory = self._directory_from_files(
files,
directory_class=SelectDirectory,
display_name="deployment.tar.gz",
)

self.onnx_folder: Directory = self._directory_from_files(
Expand Down Expand Up @@ -196,6 +209,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
self._files_dictionary = {
"training": self.training,
"deployment": self.deployment,
"deployment.tar.gz": self.deployment_tar,
"onnx_folder": self.onnx_folder,
"logs": self.logs,
"sample_originals": self.sample_originals,
Expand Down Expand Up @@ -233,9 +247,9 @@ def deployment_directory_path(self) -> str:
deployment directory if compressed
"""
# trigger initial download if not downloaded
self.deployment.path
if self.deployment.is_archive:
self.deployment.unzip()
self.deployment_tar.path
if self.deployment_tar.is_archive:
self.deployment_tar.unzip()

return self.deployment.path

Expand Down Expand Up @@ -310,6 +324,12 @@ def download(
else:
downloads = []
for key, file in self._files_dictionary.items():
if key == "deployment":
# skip the download of the deployment directory
# since identical files will be downloaded
# in the deployment_tar
_LOGGER.debug(f"Intentionally skipping downloading the file {key}")
continue
if file is not None:
# save all the files to a temporary directory
downloads.append(self._download(file, download_path))
Expand Down
15 changes: 14 additions & 1 deletion src/sparsezoo/objects/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,22 @@ def unzip(self, extract_directory: Optional[str] = None, force: bool = False):
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(
File(name=member.name, path=os.path.join(path, member.name))
File(
name=member.name,
path=os.path.join(path, member.name),
parent_directory=path,
)
)
tar.close()
# if path already exists, then the tar archive has already been unzipped
# and we can just use the files in the directory
elif os.path.exists(path):
for file in os.listdir(path):
files.append(
File(
name=file, path=os.path.join(path, file), parent_directory=path
)
)

self.name = name
self.files = files
Expand Down
56 changes: 56 additions & 0 deletions tests/sparsezoo/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import pytest

from sparsezoo import Model
from sparsezoo.objects.directories import SelectDirectory


files_ic = {
"training",
"deployment.tar.gz",
"deployment",
"logs",
"onnx",
Expand Down Expand Up @@ -182,6 +184,10 @@ def setup(self, stub, clone_sample_outputs, expected_files):
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
model = Model(stub, temp_dir.name)
model.download()
# since downloading the `deployment` file is
# disabled by default, we need to do it
# explicitly
model.deployment.download()
self._add_mock_files(temp_dir.name, clone_sample_outputs=clone_sample_outputs)
model = Model(temp_dir.name)

Expand Down Expand Up @@ -329,6 +335,56 @@ def test_model_gz_extraction_from_local_files(stub: str):
shutil.rmtree(temp_dir.name)


@pytest.mark.parametrize(
"stub",
[
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/"
"imagenet/pruned-moderate",
],
)
def test_model_deployment_directory(stub):
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
expected_deployment_files = ["model.onnx"]

model = Model(stub, temp_dir.name)
assert model.deployment_tar.is_archive
# download and extract deployment tar
deployment_dir_path = model.deployment_directory_path

# deployment and deployment_tar should be point to the same files
assert deployment_dir_path == model.deployment_tar.path == model.deployment.path
# make sure that the model contains expected files
assert set(os.listdir(temp_dir.name)) == {"deployment.tar.gz", "deployment"}
assert (
os.listdir(os.path.join(temp_dir.name, "deployment"))
== expected_deployment_files
)

assert isinstance(model.deployment, SelectDirectory)
# TODO: this should be 1. However, the API is returning for `deployment` file type
# both `model.onnx` and `deployment/model.onnx`.
# This should probably be fixed on the API side
assert (
len(model.deployment.files) == 2
) # should be == len(expected_deployment_files)

assert isinstance(model.deployment_tar, SelectDirectory)
assert len(model.deployment_tar.files) == len(expected_deployment_files)
assert not model.deployment_tar.is_archive

# test recreating the model from the local files
model = Model(temp_dir.name)

assert isinstance(model.deployment, SelectDirectory)
assert len(model.deployment.files) == len(expected_deployment_files)

assert isinstance(model.deployment_tar, SelectDirectory)
assert len(model.deployment_tar.files) == len(expected_deployment_files)
assert not model.deployment_tar.is_archive

shutil.rmtree(temp_dir.name)


def _extraction_test_helper(model: Model):
# download and extract model.onnx.tar.gz
# path should point to extracted model.onnx file
Expand Down

0 comments on commit 9b5daea

Please sign in to comment.