From b3c4ad437a9ee48834a972c00483f7374ed25a0e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 19 Jun 2023 10:12:49 -0400 Subject: [PATCH 01/19] add zarr dependency Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/installation.md | 4 ++-- requirements-dev.txt | 1 + setup.cfg | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index c3e7297da6..eb7adb06fb 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f733ac723..71eb26fda2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -52,3 +52,4 @@ onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 +zarr \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index c7dcf384b8..c218b133ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,6 +79,7 @@ all = optuna onnx>=1.13.0 onnxruntime; python_version <= '3.10' + zarr nibabel = nibabel ninja = @@ -142,6 +143,8 @@ optuna = onnx = onnx>=1.13.0 onnxruntime; python_version <= '3.10' +zarr = + zarr # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded From 5b7ea1ddccbb4feeb4437a6f0cf7fb63b42e39e4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:15:05 -0400 Subject: [PATCH 02/19] Implement ZarrAvgMerger Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/__init__.py | 2 +- monai/inferers/merger.py | 198 +++++++++++++++++++++++++++++++++++-- 2 files changed, 193 insertions(+), 7 deletions(-) diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index bbd361ca79..960380bfb8 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -20,6 +20,6 @@ SlidingWindowInferer, SlidingWindowInfererAdapt, ) -from .merger import AvgMerger, Merger +from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter from .utils import sliding_window_inference diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 63c39aed24..fc51047997 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -13,13 +13,20 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING, Any +import numpy as np import torch -from monai.utils import ensure_tuple_size +from monai.utils import ensure_tuple_size, optional_import, require_pkg -__all__ = ["Merger", "AvgMerger"] +if TYPE_CHECKING: + import zarr +else: + zarr, _ = optional_import("zarr") + + +__all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"] class Merger(ABC): @@ -97,9 +104,9 @@ def __init__( self, merged_shape: Sequence[int], cropped_shape: Sequence[int] | None = None, - device: torch.device | str = "cpu", value_dtype: torch.dtype = torch.float32, count_dtype: torch.dtype = torch.uint8, + device: torch.device | str = "cpu", ) -> None: super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device) if not self.merged_shape: @@ -152,12 +159,21 @@ def finalize(self) -> torch.Tensor: return self.values + def get_output(self) -> torch.Tensor: + """ + Get the final merged output. + + Returns: + torch.Tensor: merged output. + """ + return self.finalize() + def get_values(self) -> torch.Tensor: """ Get the accumulated values during aggregation or final averaged values after it is finalized. Returns: - Merged (averaged) output tensor. + torch.tensor: aggregated values. Notes: - If called before calling `finalize()`, this method returns the accumulating values. @@ -170,6 +186,176 @@ def get_counts(self) -> torch.Tensor: Get the aggregator tensor for number of samples. Returns: - torch.Tensor: Number of accumulated samples at each location. + torch.Tensor: number of accumulated samples at each location. + """ + return self.counts + + +@require_pkg(pkg_name="zarr") +class ZarrAvgMerger(Merger): + """Merge patches by taking average of the overlapping area and store the results in zarr array. + + Args: + merged_shape: the shape of the tensor required to merge the patches. + cropped_shape: the shape of the final merged output tensor. + If not provided, it will be the same as `merged_shape`. + output_dtype: the dtype for the final result. + value_dtype: the dtype for value aggregating tensor and the final result. + count_dtype: the dtype for sample counting tensor. + store: the zarr store to save the final results. + value_store: the zarr store to save the value aggregating tensor. + count_store: the zarr store to save the sample counting tensor. + compressor: the compressor for zarr array. + chunks : int or tuple of ints, optional + Chunk shape. If True, will be guessed from `shape` and `dtype`. If + False, will be set to `shape`, i.e., single chunk for the whole array. + If an int, the chunk size in each dimension will be given by the value + of `chunks`. Default is True. + """ + + def __init__( + self, + merged_shape: Sequence[int], + cropped_shape: Sequence[int] | None = None, + output_dtype: np.dtype | str = "float32", + value_dtype: np.dtype | str = "float32", + count_dtype: np.dtype | str = "uint8", + store: zarr.storage.Store | str = "merged.zarr", + value_store: zarr.storage.Store | str = zarr.storage.TempStore(), + count_store: zarr.storage.Store | str = zarr.storage.TempStore(), + compressor: str = "default", + chunks: Sequence[int] | bool = True, + ) -> None: + super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape) + if not self.merged_shape: + raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.") + self.output_dtype = output_dtype + self.value_dtype = value_dtype + self.count_dtype = count_dtype + self.store = store + self.chunks = chunks + self.compressor = compressor + self.output = zarr.empty( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.output_dtype, + compressor=self.compressor, + store=store, + overwrite=True, + ) + self.values = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.value_dtype, + compressor=self.compressor, + store=value_store, + overwrite=True, + ) + self.counts = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.count_dtype, + compressor=self.compressor, + store=count_store, + overwrite=True, + ) + + def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: + """ + Aggregate values for merging. + + Args: + values: a tensor of shape BCHW[D], representing the values of inference output. + location: a tuple/list giving the top left location of the patch in the original image. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + if self.is_finalized: + raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.") + patch_size = values.shape[2:] + map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size)) + map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True) + self.values[map_slice] += values.numpy() + self.counts[map_slice] += 1 + + def finalize(self) -> zarr.Array: + """ + Finalize merging by dividing values by counts and return the merged tensor. + + Notes: + To avoid creating a new tensor for the final results (to save memory space), + after this method is called, `get_values()` method will return the "final" averaged values, + and not the accumulating values. Also calling `finalize()` multiple times does not have any effect. + + Returns: + zarr.Array: a zarr array of of merged patches + """ + # guard against multiple call to finalize + if not self.is_finalized: + # use chunks for division to be able to fit them into memory + self.output[:] = self.values[:] / self.counts[:] + # for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): + # self.output[chunk] = self.values[chunk] / self.counts[chunk] + # finalize the shape + self.output.resize(self.cropped_shape) + # set finalize flag to protect performing in-place division again + self.is_finalized = True + + return self.output + + def get_output(self) -> zarr.Array: + """ + Get the final merged output. + + Returns: + zarr.Array: Merged (averaged) output tensor. + """ + return self.output + + def get_values(self) -> zarr.Array: + """ + Get the accumulated values during aggregation + + Returns: + zarr.Array: aggregated values. + + """ + return self.values + + def get_counts(self) -> zarr.Array: + """ + Get the aggregator tensor for number of samples. + + Returns: + zarr.Array: Number of accumulated samples at each location. """ return self.counts + + +def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()): + """ + Iterate over chunks of a given shape. + + Args: + chunks: the chunk shape + cdata_shape: the shape of the data in chunks + slice_tuple: the slice tuple to be used for indexing + + Raises: + ValueError: When the length of chunks and cdata_shape are not the same. + + Yields: + slices of the data + """ + if len(chunks) != len(cdata_shape): + raise ValueError("chunks and cdata_shape must have the same length") + if len(chunks) == 1: + for i in range(cdata_shape[0]): + yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),) + else: + for i in range(cdata_shape[0]): + yield from iterate_over_chunks( + chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),) + ) From 571498c60cabdca210f5e08fa28a6deb808c17b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jun 2023 18:32:20 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 71eb26fda2..78e3b7381a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -52,4 +52,4 @@ onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 -zarr \ No newline at end of file +zarr From fe1cbc81372514c053ecfcb4bf1b228f19d1ada4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 20 Jun 2023 14:47:19 -0400 Subject: [PATCH 04/19] update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/inferers.rst | 5 +++++ monai/inferers/merger.py | 12 +++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 3bf6af15b0..5f94e1a899 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -77,6 +77,11 @@ Mergers :members: :special-members: __call__ +`ZarrAvgMerger` +~~~~~~~~~~~ +.. autoclass:: ZarrAvgMerger + :members: + :special-members: __call__ Sliding Window Inference Function diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index fc51047997..04131cbc57 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -221,8 +221,8 @@ def __init__( value_dtype: np.dtype | str = "float32", count_dtype: np.dtype | str = "uint8", store: zarr.storage.Store | str = "merged.zarr", - value_store: zarr.storage.Store | str = zarr.storage.TempStore(), - count_store: zarr.storage.Store | str = zarr.storage.TempStore(), + value_store: zarr.storage.Store | str | None = None, + count_store: zarr.storage.Store | str | None = None, compressor: str = "default", chunks: Sequence[int] | bool = True, ) -> None: @@ -233,6 +233,8 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor self.output = zarr.empty( @@ -240,7 +242,7 @@ def __init__( chunks=self.chunks, dtype=self.output_dtype, compressor=self.compressor, - store=store, + store=self.store, overwrite=True, ) self.values = zarr.zeros( @@ -248,7 +250,7 @@ def __init__( chunks=self.chunks, dtype=self.value_dtype, compressor=self.compressor, - store=value_store, + store=self.value_store, overwrite=True, ) self.counts = zarr.zeros( @@ -256,7 +258,7 @@ def __init__( chunks=self.chunks, dtype=self.count_dtype, compressor=self.compressor, - store=count_store, + store=self.count_store, overwrite=True, ) From 0a1aa3d085e9821920fb06169b4ed01fd59e41a5 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 20 Jun 2023 15:05:08 -0400 Subject: [PATCH 05/19] update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/inferers.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 5f94e1a899..0011a489f3 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -78,7 +78,7 @@ Mergers :special-members: __call__ `ZarrAvgMerger` -~~~~~~~~~~~ +~~~~~~~~~~~~~~~ .. autoclass:: ZarrAvgMerger :members: :special-members: __call__ From 8e8594e500b3f3f7c29f1e86048ab7aad93efd25 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 20 Jun 2023 15:09:49 -0400 Subject: [PATCH 06/19] update docstring Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 04131cbc57..54622a94a1 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -199,18 +199,17 @@ class ZarrAvgMerger(Merger): merged_shape: the shape of the tensor required to merge the patches. cropped_shape: the shape of the final merged output tensor. If not provided, it will be the same as `merged_shape`. - output_dtype: the dtype for the final result. - value_dtype: the dtype for value aggregating tensor and the final result. - count_dtype: the dtype for sample counting tensor. - store: the zarr store to save the final results. - value_store: the zarr store to save the value aggregating tensor. - count_store: the zarr store to save the sample counting tensor. - compressor: the compressor for zarr array. - chunks : int or tuple of ints, optional - Chunk shape. If True, will be guessed from `shape` and `dtype`. If - False, will be set to `shape`, i.e., single chunk for the whole array. - If an int, the chunk size in each dimension will be given by the value - of `chunks`. Default is True. + output_dtype: the dtype for the final result. Default is `float32`. + value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`. + count_dtype: the dtype for sample counting tensor. Default is `uint8`. + store: the zarr store to save the final results. Default is "merged.zarr". + value_store: the zarr store to save the value aggregating tensor. Default is a temporary store. + count_store: the zarr store to save the sample counting tensor. Default is a temporary store. + compressor: the compressor for zarr array. Default is "default". + chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. + If True, chunk shape will be guessed from `shape` and `dtype`. + If False, ir will be set to `shape`, i.e., single chunk for the whole array. + If an int, the chunk size in each dimension will be given by the value of `chunks`. """ def __init__( From 1a9d0f81260d7d8744fdae7a6b2b777f62fc213f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 21 Jun 2023 10:36:29 -0400 Subject: [PATCH 07/19] add info about Zarr in docstring Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/requirements.txt | 1 + monai/inferers/merger.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9369548c67..0f8ccbdadd 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -37,3 +37,4 @@ optuna opencv-python-headless onnx>=1.13.0 onnxruntime; python_version <= '3.10' +zarr \ No newline at end of file diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 54622a94a1..94eb70817c 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -195,6 +195,14 @@ def get_counts(self) -> torch.Tensor: class ZarrAvgMerger(Merger): """Merge patches by taking average of the overlapping area and store the results in zarr array. + Zarr is a format for the storage of chunked, compressed, N-dimensional arrays. + Zarr data can be stored in any storage system that can be represented as a key-value store, + like POSIX file systems, cloud object storage, zip files, and relational and document databases. + See https://zarr.readthedocs.io/en/stable/ for more details. + It is particularly useful for storing N-dimensional arrays too large to fit into memory. + One specific use case of this class is to merge patches extracted from whole slide images (WSI), + where the merged results does not fit into memory and need to be stored on a file system. + Args: merged_shape: the shape of the tensor required to merge the patches. cropped_shape: the shape of the final merged output tensor. From 5c5120010a73043525a6f488ce2c4590ad95a5f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Jun 2023 14:36:59 +0000 Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 0f8ccbdadd..07b189dd79 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -37,4 +37,4 @@ optuna opencv-python-headless onnx>=1.13.0 onnxruntime; python_version <= '3.10' -zarr \ No newline at end of file +zarr From fdb396988d5dd419868b0c6d3067653c0f23aa03 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 21 Jun 2023 11:00:04 -0400 Subject: [PATCH 09/19] add unit tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_zarr_avg_merger.py | 243 ++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 tests/test_zarr_avg_merger.py diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py new file mode 100644 index 0000000000..ef3bf2670e --- /dev/null +++ b/tests/test_zarr_avg_merger.py @@ -0,0 +1,243 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized +from torch.nn.functional import pad + +from monai.inferers import ZarrAvgMerger +from monai.utils import optional_import +from tests.utils import assert_allclose + +np.seterr(divide="ignore", invalid="ignore") +zarr, has_zarr = optional_import("zarr") + +TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) +TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone() +TENSOR_4x4_WITH_NAN[..., 2:, 2:] = float("nan") + +# no-overlapping 2x2 +TEST_CASE_0_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 +TEST_CASE_1_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 3x3 (non-divisible) +TEST_CASE_2_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :3, :3], (0, 0)), + (TENSOR_4x4[..., :3, 1:], (0, 1)), + (TENSOR_4x4[..., 1:, :3], (1, 0)), + (TENSOR_4x4[..., 1:, 1:], (1, 1)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 with NaN values +TEST_CASE_3_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4_WITH_NAN.shape), + [ + (TENSOR_4x4_WITH_NAN[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4_WITH_NAN, +] + +# non-overlapping 2x2 with missing patch +TEST_CASE_4_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [(TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), (TENSOR_4x4[..., 2:, :2], (2, 0))], + TENSOR_4x4_WITH_NAN, +] + +# with value_dtype set to half precision +TEST_CASE_5_VALUE_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float16), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with count_dtype set to int32 +TEST_CASE_6_COUNT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, count_dtype=np.int32), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with both value_dtype, count_dtype set to double precision +TEST_CASE_7_COUNT_VALUE_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float64, count_dtype=np.float64), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with both value_dtype, count_dtype set to double precision +TEST_CASE_8_OUTPUT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, output_dtype=np.float64), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# shape larger than what is covered by patches +TEST_CASE_9_LARGER_SHAPE = [ + dict(merged_shape=(2, 3, 4, 6)), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + pad(TENSOR_4x4, (0, 2), value=float("nan")), +] + + +# explicit directory store +TEST_CASE_10_DIRECTORY_STORE = [ + dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# memory store for all arrays +TEST_CASE_11_MEMORY_STORE = [ + dict( + merged_shape=TENSOR_4x4.shape, + store=zarr.storage.MemoryStore(), + value_store=zarr.storage.MemoryStore(), + count_store=zarr.storage.MemoryStore(), + ), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# explicit chunk size +TEST_CASE_12_CHUNKS = [ + dict(merged_shape=TENSOR_4x4.shape, chunks=(1, 1, 2, 2)), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +class ZarrAvgMergerTests(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0_DEFAULT_DTYPE, + TEST_CASE_1_DEFAULT_DTYPE, + TEST_CASE_2_DEFAULT_DTYPE, + TEST_CASE_3_DEFAULT_DTYPE, + TEST_CASE_4_DEFAULT_DTYPE, + TEST_CASE_5_VALUE_DTYPE, + TEST_CASE_6_COUNT_DTYPE, + TEST_CASE_7_COUNT_VALUE_DTYPE, + TEST_CASE_8_OUTPUT_DTYPE, + TEST_CASE_9_LARGER_SHAPE, + TEST_CASE_10_DIRECTORY_STORE, + TEST_CASE_11_MEMORY_STORE, + TEST_CASE_12_CHUNKS, + ] + ) + def test_avg_merger_patches(self, arguments, patch_locations, expected): + merger = ZarrAvgMerger(**arguments) + for pl in patch_locations: + merger.aggregate(pl[0], pl[1]) + output = merger.finalize() + if "value_dtype" in arguments: + self.assertTrue(merger.get_values().dtype, arguments["value_dtype"]) + if "count_dtype" in arguments: + self.assertTrue(merger.get_counts().dtype, arguments["count_dtype"]) + # check for multiple call of finalize + self.assertIs(output, merger.finalize()) + # check if the result is matching the expectation + assert_allclose(output[:], expected.numpy()) + + def test_avg_merger_finalized_error(self): + with self.assertRaises(ValueError): + merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3)) + merger.finalize() + merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3)) + + def test_avg_merge_none_merged_shape_error(self): + with self.assertRaises(ValueError): + ZarrAvgMerger(merged_shape=None) + + +if __name__ == "__main__": + unittest.main() From 70bf188d2d1f113226313410c844e1d8e45b58ac Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 21 Jun 2023 12:10:53 -0400 Subject: [PATCH 10/19] exclude zarr tests from min tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 2fc22452d0..f553dc4a50 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -202,6 +202,7 @@ def run_testsuit(): "test_metrics_reloaded", "test_spatial_combine_transforms", "test_bundle_workflow", + "test_zarr_avg_merger", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From e74658b1af5c1a553b7542bb6a3197a40b06fa85 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 26 Jun 2023 09:25:57 -0400 Subject: [PATCH 11/19] add thread locks Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 94eb70817c..c954d62e08 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -11,6 +11,7 @@ from __future__ import annotations +import threading from abc import ABC, abstractmethod from collections.abc import Sequence from typing import TYPE_CHECKING, Any @@ -268,6 +269,8 @@ def __init__( store=self.count_store, overwrite=True, ) + self.lock = threading.Lock() + self.lock_finalize = threading.Lock() def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: """ @@ -286,8 +289,9 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: patch_size = values.shape[2:] map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size)) map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True) - self.values[map_slice] += values.numpy() - self.counts[map_slice] += 1 + with self.lock: + self.values[map_slice] += values.numpy() + self.counts[map_slice] += 1 def finalize(self) -> zarr.Array: """ @@ -301,16 +305,17 @@ def finalize(self) -> zarr.Array: Returns: zarr.Array: a zarr array of of merged patches """ - # guard against multiple call to finalize - if not self.is_finalized: - # use chunks for division to be able to fit them into memory - self.output[:] = self.values[:] / self.counts[:] - # for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): - # self.output[chunk] = self.values[chunk] / self.counts[chunk] - # finalize the shape - self.output.resize(self.cropped_shape) - # set finalize flag to protect performing in-place division again - self.is_finalized = True + # guard against possible multithreading calls + with self.lock_finalize: + # guard against multiple calls to finalize + if not self.is_finalized: + # use chunks for division to fit into memory + for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): + self.output[chunk] = self.values[chunk] / self.counts[chunk] + # finalize the shape + self.output.resize(self.cropped_shape) + # set finalize flag to protect performing in-place division again + self.is_finalized = True return self.output From bf33fb555ffc26ead33ac5e6c185f4956f127513 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 26 Jun 2023 09:41:49 -0400 Subject: [PATCH 12/19] remove compression for temp zarr arrays Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index c954d62e08..d6a4b84084 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -214,7 +214,8 @@ class ZarrAvgMerger(Merger): store: the zarr store to save the final results. Default is "merged.zarr". value_store: the zarr store to save the value aggregating tensor. Default is a temporary store. count_store: the zarr store to save the sample counting tensor. Default is a temporary store. - compressor: the compressor for zarr array. Default is "default". + compressor: the compressor for final merged zarr array. Default is "default". + The compressor for temporary zarr arrays (values and counts) will be set to None. chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. If True, chunk shape will be guessed from `shape` and `dtype`. If False, ir will be set to `shape`, i.e., single chunk for the whole array. @@ -257,7 +258,7 @@ def __init__( shape=self.merged_shape, chunks=self.chunks, dtype=self.value_dtype, - compressor=self.compressor, + compressor=None, store=self.value_store, overwrite=True, ) @@ -265,7 +266,7 @@ def __init__( shape=self.merged_shape, chunks=self.chunks, dtype=self.count_dtype, - compressor=self.compressor, + compressor=None, store=self.count_store, overwrite=True, ) From 5bc1f943c0af3e971f398eff56e628228f68ce83 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 26 Jun 2023 11:01:11 -0400 Subject: [PATCH 13/19] add flexibility to define compressors for all zarr arrays Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 17 ++++++---- tests/test_zarr_avg_merger.py | 63 +++++++++++++++++++++++++++++++---- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index d6a4b84084..c1cb825d63 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -208,14 +208,15 @@ class ZarrAvgMerger(Merger): merged_shape: the shape of the tensor required to merge the patches. cropped_shape: the shape of the final merged output tensor. If not provided, it will be the same as `merged_shape`. - output_dtype: the dtype for the final result. Default is `float32`. + dtype: the dtype for the final merged result. Default is `float32`. value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`. count_dtype: the dtype for sample counting tensor. Default is `uint8`. store: the zarr store to save the final results. Default is "merged.zarr". value_store: the zarr store to save the value aggregating tensor. Default is a temporary store. count_store: the zarr store to save the sample counting tensor. Default is a temporary store. compressor: the compressor for final merged zarr array. Default is "default". - The compressor for temporary zarr arrays (values and counts) will be set to None. + value_compressor: the compressor for value aggregating zarr array. Default is None. + count_compressor: the compressor for sample counting zarr array. Default is None. chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. If True, chunk shape will be guessed from `shape` and `dtype`. If False, ir will be set to `shape`, i.e., single chunk for the whole array. @@ -226,19 +227,21 @@ def __init__( self, merged_shape: Sequence[int], cropped_shape: Sequence[int] | None = None, - output_dtype: np.dtype | str = "float32", + dtype: np.dtype | str = "float32", value_dtype: np.dtype | str = "float32", count_dtype: np.dtype | str = "uint8", store: zarr.storage.Store | str = "merged.zarr", value_store: zarr.storage.Store | str | None = None, count_store: zarr.storage.Store | str | None = None, compressor: str = "default", + value_compressor: str | None = None, + count_compressor: str | None = None, chunks: Sequence[int] | bool = True, ) -> None: super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape) if not self.merged_shape: raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.") - self.output_dtype = output_dtype + self.output_dtype = dtype self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store @@ -246,6 +249,8 @@ def __init__( self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor + self.value_compressor = value_compressor + self.count_compressor = count_compressor self.output = zarr.empty( shape=self.merged_shape, chunks=self.chunks, @@ -258,7 +263,7 @@ def __init__( shape=self.merged_shape, chunks=self.chunks, dtype=self.value_dtype, - compressor=None, + compressor=self.value_compressor, store=self.value_store, overwrite=True, ) @@ -266,7 +271,7 @@ def __init__( shape=self.merged_shape, chunks=self.chunks, dtype=self.count_dtype, - compressor=None, + compressor=self.count_compressor, store=self.count_store, overwrite=True, ) diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index ef3bf2670e..016d5cc820 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -24,6 +24,7 @@ np.seterr(divide="ignore", invalid="ignore") zarr, has_zarr = optional_import("zarr") +numcodecs, has_numcodecs = optional_import("numcodecs") TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone() @@ -128,8 +129,8 @@ TENSOR_4x4, ] # with both value_dtype, count_dtype set to double precision -TEST_CASE_8_OUTPUT_DTYPE = [ - dict(merged_shape=TENSOR_4x4.shape, output_dtype=np.float64), +TEST_CASE_8_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, dtype=np.float64), [ (TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), @@ -196,6 +197,44 @@ ] +# test for LZ4 compressor +TEST_CASE_13_COMPRESSOR_LZ4 = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test for pickle compressor +TEST_CASE_14_COMPRESSOR_PICKLE = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test for LZMA compressor +TEST_CASE_15_COMPRESSOR_LZMA = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +@unittest.skipIf(not has_zarr or not has_numcodecs, "Requires zarr (and numcodecs) packages.)") class ZarrAvgMergerTests(unittest.TestCase): @parameterized.expand( [ @@ -207,14 +246,26 @@ class ZarrAvgMergerTests(unittest.TestCase): TEST_CASE_5_VALUE_DTYPE, TEST_CASE_6_COUNT_DTYPE, TEST_CASE_7_COUNT_VALUE_DTYPE, - TEST_CASE_8_OUTPUT_DTYPE, + TEST_CASE_8_DTYPE, TEST_CASE_9_LARGER_SHAPE, TEST_CASE_10_DIRECTORY_STORE, TEST_CASE_11_MEMORY_STORE, TEST_CASE_12_CHUNKS, + TEST_CASE_13_COMPRESSOR_LZ4, + TEST_CASE_14_COMPRESSOR_PICKLE, + TEST_CASE_15_COMPRESSOR_LZMA, ] ) - def test_avg_merger_patches(self, arguments, patch_locations, expected): + def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + if "compressor" in arguments: + if arguments["compressor"] != "default": + arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + if "value_compressor" in arguments: + if arguments["value_compressor"] != "default": + arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + if "count_compressor" in arguments: + if arguments["count_compressor"] != "default": + arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1]) @@ -228,13 +279,13 @@ def test_avg_merger_patches(self, arguments, patch_locations, expected): # check if the result is matching the expectation assert_allclose(output[:], expected.numpy()) - def test_avg_merger_finalized_error(self): + def test_zarr_avg_merger_finalized_error(self): with self.assertRaises(ValueError): merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3)) merger.finalize() merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3)) - def test_avg_merge_none_merged_shape_error(self): + def test_zarr_avg_merge_none_merged_shape_error(self): with self.assertRaises(ValueError): ZarrAvgMerger(merged_shape=None) From 9bd6ce16d1cce05f2943e7dac1a376c9e0866be1 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 26 Jun 2023 11:04:30 -0400 Subject: [PATCH 14/19] change to skip unless Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_zarr_avg_merger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index 016d5cc820..bca76b730f 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -234,7 +234,7 @@ ] -@unittest.skipIf(not has_zarr or not has_numcodecs, "Requires zarr (and numcodecs) packages.)") +@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") class ZarrAvgMergerTests(unittest.TestCase): @parameterized.expand( [ From 6c231eaa48348557abbac82e6bd9f42d0734deb4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 28 Jun 2023 09:46:52 -0400 Subject: [PATCH 15/19] make thread locking optional Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 31 ++++++++++++++++++------------- tests/test_zarr_avg_merger.py | 27 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index c1cb825d63..a723484e3b 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -14,6 +14,7 @@ import threading from abc import ABC, abstractmethod from collections.abc import Sequence +from contextlib import nullcontext from typing import TYPE_CHECKING, Any import numpy as np @@ -237,6 +238,7 @@ def __init__( value_compressor: str | None = None, count_compressor: str | None = None, chunks: Sequence[int] | bool = True, + thread_locking: bool = True, ) -> None: super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape) if not self.merged_shape: @@ -275,8 +277,13 @@ def __init__( store=self.count_store, overwrite=True, ) - self.lock = threading.Lock() - self.lock_finalize = threading.Lock() + self.lock: threading.Lock | nullcontext + if thread_locking: + # use lock to protect the in-place addition during aggregation + self.lock = threading.Lock() + else: + # use nullcontext to avoid the locking if not needed + self.lock = nullcontext() def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: """ @@ -311,17 +318,15 @@ def finalize(self) -> zarr.Array: Returns: zarr.Array: a zarr array of of merged patches """ - # guard against possible multithreading calls - with self.lock_finalize: - # guard against multiple calls to finalize - if not self.is_finalized: - # use chunks for division to fit into memory - for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): - self.output[chunk] = self.values[chunk] / self.counts[chunk] - # finalize the shape - self.output.resize(self.cropped_shape) - # set finalize flag to protect performing in-place division again - self.is_finalized = True + # guard against multiple calls to finalize + if not self.is_finalized: + # use chunks for division to fit into memory + for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): + self.output[chunk] = self.values[chunk] / self.counts[chunk] + # finalize the shape + self.output.resize(self.cropped_shape) + # set finalize flag to protect performing in-place division again + self.is_finalized = True return self.output diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index bca76b730f..cbc713b442 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -234,6 +234,31 @@ ] +# test with thread locking +TEST_CASE_16_WITH_LOCK = [ + dict(merged_shape=TENSOR_4x4.shape, thread_locking=True), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test without thread locking +TEST_CASE_17_WITHOUT_LOCK = [ + dict(merged_shape=TENSOR_4x4.shape, thread_locking=False), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + @unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") class ZarrAvgMergerTests(unittest.TestCase): @parameterized.expand( @@ -254,6 +279,8 @@ class ZarrAvgMergerTests(unittest.TestCase): TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA, + TEST_CASE_16_WITH_LOCK, + TEST_CASE_17_WITHOUT_LOCK, ] ) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): From 8145c084c609c50a095592da2c697371e3987189 Mon Sep 17 00:00:00 2001 From: "Dr. Behrooz Hashemian" <3968947+drbeh@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:11:15 -0400 Subject: [PATCH 16/19] Update monai/inferers/merger.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index a723484e3b..7ed9668d37 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -203,7 +203,7 @@ class ZarrAvgMerger(Merger): See https://zarr.readthedocs.io/en/stable/ for more details. It is particularly useful for storing N-dimensional arrays too large to fit into memory. One specific use case of this class is to merge patches extracted from whole slide images (WSI), - where the merged results does not fit into memory and need to be stored on a file system. + where the merged results do not fit into memory and need to be stored on a file system. Args: merged_shape: the shape of the tensor required to merge the patches. From 3b84aee4e073ea07b6c0592f4d1b90acf0f7c370 Mon Sep 17 00:00:00 2001 From: "Dr. Behrooz Hashemian" <3968947+drbeh@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:11:25 -0400 Subject: [PATCH 17/19] Update monai/inferers/merger.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 7ed9668d37..b0510565aa 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -220,7 +220,7 @@ class ZarrAvgMerger(Merger): count_compressor: the compressor for sample counting zarr array. Default is None. chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. If True, chunk shape will be guessed from `shape` and `dtype`. - If False, ir will be set to `shape`, i.e., single chunk for the whole array. + If False, it will be set to `shape`, i.e., single chunk for the whole array. If an int, the chunk size in each dimension will be given by the value of `chunks`. """ From a7f4f939d194b48cfd93f399c0954c4cb7f3d020 Mon Sep 17 00:00:00 2001 From: "Dr. Behrooz Hashemian" <3968947+drbeh@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:12:40 -0400 Subject: [PATCH 18/19] Update monai/inferers/merger.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> --- monai/inferers/merger.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index b0510565aa..9901951928 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -292,10 +292,6 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: Args: values: a tensor of shape BCHW[D], representing the values of inference output. location: a tuple/list giving the top left location of the patch in the original image. - - Raises: - NotImplementedError: When the subclass does not override this method. - """ if self.is_finalized: raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.") From 7cb71a00edde770cab63b8f7681da048b35d4ed5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 28 Jun 2023 18:27:37 +0100 Subject: [PATCH 19/19] unblock premerge download test Signed-off-by: Wenqi Li --- tests/test_download_url_yandex.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_download_url_yandex.py b/tests/test_download_url_yandex.py index d0946f9f70..a08105a93f 100644 --- a/tests/test_download_url_yandex.py +++ b/tests/test_download_url_yandex.py @@ -29,6 +29,7 @@ class TestDownloadUrlYandex(unittest.TestCase): + @unittest.skip("data source unstable") def test_verify(self): with tempfile.TemporaryDirectory() as tempdir: download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt"))