Skip to content

Commit

Permalink
add flexibility to define compressors for all zarr arrays
Browse files Browse the repository at this point in the history
Signed-off-by: Behrooz <[email protected]>
  • Loading branch information
drbeh committed Jun 26, 2023
1 parent e0c35eb commit 3295cff
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 12 deletions.
17 changes: 11 additions & 6 deletions monai/inferers/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,15 @@ class ZarrAvgMerger(Merger):
merged_shape: the shape of the tensor required to merge the patches.
cropped_shape: the shape of the final merged output tensor.
If not provided, it will be the same as `merged_shape`.
output_dtype: the dtype for the final result. Default is `float32`.
dtype: the dtype for the final merged result. Default is `float32`.
value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`.
count_dtype: the dtype for sample counting tensor. Default is `uint8`.
store: the zarr store to save the final results. Default is "merged.zarr".
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
compressor: the compressor for final merged zarr array. Default is "default".
The compressor for temporary zarr arrays (values and counts) will be set to None.
value_compressor: the compressor for value aggregating zarr array. Default is None.
count_compressor: the compressor for sample counting zarr array. Default is None.
chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
If True, chunk shape will be guessed from `shape` and `dtype`.
If False, ir will be set to `shape`, i.e., single chunk for the whole array.
Expand All @@ -226,26 +227,30 @@ def __init__(
self,
merged_shape: Sequence[int],
cropped_shape: Sequence[int] | None = None,
output_dtype: np.dtype | str = "float32",
dtype: np.dtype | str = "float32",
value_dtype: np.dtype | str = "float32",
count_dtype: np.dtype | str = "uint8",
store: zarr.storage.Store | str = "merged.zarr",
value_store: zarr.storage.Store | str | None = None,
count_store: zarr.storage.Store | str | None = None,
compressor: str = "default",
value_compressor: str | None = None,
count_compressor: str | None = None,
chunks: Sequence[int] | bool = True,
) -> None:
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
if not self.merged_shape:
raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.")
self.output_dtype = output_dtype
self.output_dtype = dtype
self.value_dtype = value_dtype
self.count_dtype = count_dtype
self.store = store
self.value_store = zarr.storage.TempStore() if value_store is None else value_store
self.count_store = zarr.storage.TempStore() if count_store is None else count_store
self.chunks = chunks
self.compressor = compressor
self.value_compressor = value_compressor
self.count_compressor = count_compressor
self.output = zarr.empty(
shape=self.merged_shape,
chunks=self.chunks,
Expand All @@ -258,15 +263,15 @@ def __init__(
shape=self.merged_shape,
chunks=self.chunks,
dtype=self.value_dtype,
compressor=None,
compressor=self.value_compressor,
store=self.value_store,
overwrite=True,
)
self.counts = zarr.zeros(
shape=self.merged_shape,
chunks=self.chunks,
dtype=self.count_dtype,
compressor=None,
compressor=self.count_compressor,
store=self.count_store,
overwrite=True,
)
Expand Down
63 changes: 57 additions & 6 deletions tests/test_zarr_avg_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

np.seterr(divide="ignore", invalid="ignore")
zarr, has_zarr = optional_import("zarr")
numcodecs, has_numcodecs = optional_import("numcodecs")

TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)
TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone()
Expand Down Expand Up @@ -128,8 +129,8 @@
TENSOR_4x4,
]
# with both value_dtype, count_dtype set to double precision
TEST_CASE_8_OUTPUT_DTYPE = [
dict(merged_shape=TENSOR_4x4.shape, output_dtype=np.float64),
TEST_CASE_8_DTYPE = [
dict(merged_shape=TENSOR_4x4.shape, dtype=np.float64),
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
Expand Down Expand Up @@ -196,6 +197,44 @@
]


# test for LZ4 compressor
TEST_CASE_13_COMPRESSOR_LZ4 = [
dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4"),
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
(TENSOR_4x4[..., 2:, :2], (2, 0)),
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
],
TENSOR_4x4,
]

# test for pickle compressor
TEST_CASE_14_COMPRESSOR_PICKLE = [
dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle"),
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
(TENSOR_4x4[..., 2:, :2], (2, 0)),
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
],
TENSOR_4x4,
]

# test for LZMA compressor
TEST_CASE_15_COMPRESSOR_LZMA = [
dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA"),
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
(TENSOR_4x4[..., 2:, :2], (2, 0)),
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
],
TENSOR_4x4,
]


@unittest.skipIf(not has_zarr or not has_numcodecs, "Requires zarr (and numcodecs) packages.)")
class ZarrAvgMergerTests(unittest.TestCase):
@parameterized.expand(
[
Expand All @@ -207,14 +246,26 @@ class ZarrAvgMergerTests(unittest.TestCase):
TEST_CASE_5_VALUE_DTYPE,
TEST_CASE_6_COUNT_DTYPE,
TEST_CASE_7_COUNT_VALUE_DTYPE,
TEST_CASE_8_OUTPUT_DTYPE,
TEST_CASE_8_DTYPE,
TEST_CASE_9_LARGER_SHAPE,
TEST_CASE_10_DIRECTORY_STORE,
TEST_CASE_11_MEMORY_STORE,
TEST_CASE_12_CHUNKS,
TEST_CASE_13_COMPRESSOR_LZ4,
TEST_CASE_14_COMPRESSOR_PICKLE,
TEST_CASE_15_COMPRESSOR_LZMA,
]
)
def test_avg_merger_patches(self, arguments, patch_locations, expected):
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
if "compressor" in arguments:
if arguments["compressor"] != "default":
arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]()
if "value_compressor" in arguments:
if arguments["value_compressor"] != "default":
arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]()
if "count_compressor" in arguments:
if arguments["count_compressor"] != "default":
arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]()
merger = ZarrAvgMerger(**arguments)
for pl in patch_locations:
merger.aggregate(pl[0], pl[1])
Expand All @@ -228,13 +279,13 @@ def test_avg_merger_patches(self, arguments, patch_locations, expected):
# check if the result is matching the expectation
assert_allclose(output[:], expected.numpy())

def test_avg_merger_finalized_error(self):
def test_zarr_avg_merger_finalized_error(self):
with self.assertRaises(ValueError):
merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3))
merger.finalize()
merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3))

def test_avg_merge_none_merged_shape_error(self):
def test_zarr_avg_merge_none_merged_shape_error(self):
with self.assertRaises(ValueError):
ZarrAvgMerger(merged_shape=None)

Expand Down

0 comments on commit 3295cff

Please sign in to comment.