From b62d1e118711b2665435bdac9e3cefb5d496a84a Mon Sep 17 00:00:00 2001
From: Yufan He <59374597+heyufan1995@users.noreply.github.com>
Date: Wed, 28 Aug 2024 03:14:44 -0500
Subject: [PATCH 01/15] Fix transpose and patch coords bug (#8047)

Fixes # .

### Description

Fix the bug that caused wrong results in model zoo finetuning: patch coords were not passed from sliding window to vista3d.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: heyufan1995
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 docs/requirements.txt          |  1 +
 monai/apps/vista3d/sampler.py  | 29 ++++++++++++++++++-----------
 monai/networks/nets/vista3d.py |  7 +++++--
 3 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index ff94f7b6de..7307d8e5f9 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -42,3 +42,4 @@ zarr
 huggingface_hub
 pyamg>=5.0.0
 packaging
+polygraphy

diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py
index b7aeb89a2e..17b2d34911 100644
--- a/monai/apps/vista3d/sampler.py
+++ b/monai/apps/vista3d/sampler.py
@@ -20,8 +20,6 @@
 import torch
 from torch import Tensor

-__all__ = ["sample_prompt_pairs"]
-
 ENABLE_SPECIAL = True
 SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
 MERGE_LIST = {
@@ -30,6 +28,8 @@
     132: [57],  # overlap with trachea merge into airway
 }

+__all__ = ["sample_prompt_pairs"]
+

 def _get_point_label(id: int) -> tuple[int, int]:
     if id in SPECIAL_INDEX and ENABLE_SPECIAL:
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
         max_backprompt: int, max number of prompt from background.
         max_point: maximum number of points for each object.
         include_background: if include 0 into training prompt. If included, background 0 is treated
-            the same as foreground. Always be False for multi-partial-dataset training. If needed,
-            can be true for finetuning specific dataset, .
+            the same as foreground and points will be sampled. Can be True only if the user wants to
+            segment background 0 with point clicks; otherwise it should always be False.
         drop_label_prob: probability to drop label prompt.
         drop_point_prob: probability to drop point prompt.
         point_sampler: sampler to augment masks with supervoxel.
         point_sampler_kwargs: arguments for point_sampler.

     Returns:
-        label_prompt: [B, 1]. The classes used for training automatic segmentation.
-        point: [B, N, 3]. The corresponding points for each class.
-            Note that background label prompt requires matching point as well ([0,0,0] is used).
-        point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-            -1 is used for padding the background label prompt and will be ignored.
-        prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
-            label_prompt can be None, and prompt_class is used to identify point classes.
+ tuple: + - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for + training automatic segmentation. + - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points + for each class. Note that background label prompts require matching points as well + (e.g., [0, 0, 0] is used). + - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point + labels for each point (negative or positive). -1 is used for padding the background + label prompt and will be ignored. + - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt + for label indexing during training. If label_prompt is None, prompt_class is used to + identify point classes. + """ + # class label number if not labels.shape[0] == 1: raise ValueError("only support batch size 1") diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 9148e36542..979a090df0 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): def forward( self, input_images: torch.Tensor, + patch_coords: Sequence[slice] | None = None, point_coords: torch.Tensor | None = None, point_labels: torch.Tensor | None = None, class_vector: torch.Tensor | None = None, prompt_class: torch.Tensor | None = None, - patch_coords: Sequence[slice] | None = None, labels: torch.Tensor | None = None, label_set: Sequence[int] | None = None, prev_mask: torch.Tensor | None = None, @@ -421,7 +421,10 @@ def forward( point_coords, point_labels = None, None if point_coords is None and class_vector is None: - return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + if transpose: + logits = logits.transpose(1, 0) + return logits if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: out, out_auto = self.image_embeddings, None From b6d6d77745ccacfc2e5b478bfc20975104f6ff12 Mon Sep 17 00:00:00 2001 From: stayd <77039165+staydelight@users.noreply.github.com> Date: Wed, 28 Aug 2024 18:33:00 +0800 Subject: [PATCH 02/15] Add a mapping function in image_reader.py and image_writer.py (#7769) Add a function to create a JSON file that maps input and output paths. Fixes #7557 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
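For orientation, the mapping record this patch produces is a flat JSON list of input/output pairs, appended to on each save (written with `indent=4`); a sketch of the file after one image, with illustrative paths only:

```json
[
    {
        "input": "/data/test_image.nii.gz",
        "output": "/output/test_image/test_image_trans.nii.gz"
    }
]
```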
--------- Signed-off-by: staydelight Co-authored-by: staydelight Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/transforms.rst | 12 +++ monai/transforms/__init__.py | 14 +++- monai/transforms/io/array.py | 60 ++++++++++++++- monai/transforms/io/dictionary.py | 31 +++++++- monai/utils/enums.py | 1 + tests/test_mapping_file.py | 117 ++++++++++++++++++++++++++++ tests/test_mapping_filed.py | 122 ++++++++++++++++++++++++++++++ 7 files changed, 351 insertions(+), 6 deletions(-) create mode 100644 tests/test_mapping_file.py create mode 100644 tests/test_mapping_filed.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 637f0873f1..3e45d899ec 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -554,6 +554,12 @@ IO :members: :special-members: __call__ +`WriteFileMapping` +"""""""""""""""""" +.. autoclass:: WriteFileMapping + :members: + :special-members: __call__ + NVIDIA Tool Extension (NVTX) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -1642,6 +1648,12 @@ IO (Dict) :members: :special-members: __call__ +`WriteFileMappingd` +""""""""""""""""""" +.. autoclass:: WriteFileMappingd + :members: + :special-members: __call__ + Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9548443768..f37016e63f 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,8 +238,18 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, SaveImage -from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping +from .io.dictionary import ( + LoadImaged, + LoadImageD, + LoadImageDict, + SaveImaged, + SaveImageD, + SaveImageDict, + WriteFileMappingd, + WriteFileMappingD, + WriteFileMappingDict, +) from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict from .lazy.functional import apply_pending diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7c0e8f7123..4e71870fc9 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +import json import logging import sys import traceback @@ -45,11 +46,19 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + MetaKeys, + OptionalImportError, + convert_to_dst_type, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") +FileLock, has_filelock = optional_import("filelock", name="FileLock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] @@ -505,7 +514,7 @@ def __call__( else: self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: - meta_data["saved_to"] = filename + meta_data[MetaKeys.SAVED_TO] = filename return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( @@ -514,3 +523,50 @@ def 
__call__( " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) + + +class WriteFileMapping(Transform): + """ + Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. + This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. + + Args: + mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved. + """ + + def __init__(self, mapping_file_path: Path | str = "mapping.json"): + self.mapping_file_path = Path(mapping_file_path) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: The input image with metadata. + """ + if isinstance(img, MetaTensor): + meta_data = img.meta + + if MetaKeys.SAVED_TO not in meta_data: + raise KeyError( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." + ) + + input_path = meta_data[Key.FILENAME_OR_OBJ] + output_path = meta_data[MetaKeys.SAVED_TO] + log_data = {"input": input_path, "output": output_path} + + if has_filelock: + with FileLock(str(self.mapping_file_path) + ".lock"): + self._write_to_file(log_data) + else: + self._write_to_file(log_data) + return img + + def _write_to_file(self, log_data): + try: + with self.mapping_file_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + existing_log_data.append(log_data) + with self.mapping_file_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..be1e78db8a 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,16 +17,17 @@ from __future__ import annotations +from collections.abc import Hashable, Mapping from pathlib import Path from typing import Callable import numpy as np import monai -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor from monai.data import image_writer from monai.data.image_reader import ImageReader -from monai.transforms.io.array import LoadImage, SaveImage +from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping from monai.transforms.transform import MapTransform, Transform from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep from monai.utils.enums import PostFix @@ -320,5 +321,31 @@ def __call__(self, data): return d +class WriteFileMappingd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + mapping_file_path: Path to the JSON file where the mappings will be saved. + Defaults to "mapping.json". + allow_missing_keys: don't raise exception if key is missing. 
+ """ + + def __init__( + self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mapping = WriteFileMapping(mapping_file_path) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapping(d[key]) + return d + + LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged +WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd diff --git a/monai/utils/enums.py b/monai/utils/enums.py index b786e92151..eba1be18ed 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -543,6 +543,7 @@ class MetaKeys(StrEnum): SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") + SAVED_TO = "saved_to" class ColorOrder(StrEnum): diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py new file mode 100644 index 0000000000..97fa4312ed --- /dev/null +++ b/tests/test_mapping_file.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import DataLoader, Dataset +from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping +from monai.utils import optional_import + +nib, has_nib = optional_import("nibabel") + + +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + output_ext = ".nii.gz" + input_file = os.path.join(temp_dir, name + output_ext) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True): + return Compose( + [ + LoadImage(image_only=True), + SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + WriteFileMapping(mapping_file_path=mapping_file_path), + ] + ) + + +@unittest.skipUnless(has_nib, "nibabel required") +class TestWriteFileMapping(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + @parameterized.expand([(True,), (False,)]) + def test_mapping_file(self, savepath_in_metadict): + mapping_file_path = os.path.join(self.temp_dir, "mapping.json") + name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz") + + transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict) + + if savepath_in_metadict: + transform(input_file) + self.assertTrue(os.path.exists(mapping_file_path)) + with open(mapping_file_path) as f: + mapping_data = json.load(f) + self.assertEqual(len(mapping_data), 1) + self.assertEqual(mapping_data[0]["input"], input_file) + self.assertEqual(mapping_data[0]["output"], output_file) + else: + with self.assertRaises(RuntimeError) as cm: + transform(input_file) + cause_exception = cm.exception.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), + ) + + def test_multiprocess_mapping_file(self): + num_images = 50 + + single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json") + multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json") + + data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)] + + # single process + single_transform = create_transform(self.temp_dir, single_mapping_file) + single_dataset = Dataset(data=data, transform=single_transform) + single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True) + for _ in single_loader: + pass + + # multiple processes + multi_transform = create_transform(self.temp_dir, multi_mapping_file) + multi_dataset = Dataset(data=data, transform=multi_transform) + multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True) + for _ in multi_loader: + pass + + with open(single_mapping_file) as f: + single_mapping_data = json.load(f) + with open(multi_mapping_file) as f: + multi_mapping_data = json.load(f) + + single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data} + multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data} + + self.assertEqual(single_set, multi_set) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py new file mode 100644 index 0000000000..d0f8bcf938 --- /dev/null +++ b/tests/test_mapping_filed.py @@ -0,0 +1,122 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import DataLoader, Dataset, decollate_batch +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNet +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd +from monai.utils import optional_import + +nib, has_nib = optional_import("nibabel") + + +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + input_file = os.path.join(temp_dir, name + ".nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +# Test cases that should succeed +SUCCESS_CASES = [(["seg"], ["seg"]), (["image", "seg"], ["seg"])] + +# Test cases that should fail +FAILURE_CASES = [(["seg"], ["image"]), (["image"], ["seg"]), (["seg"], ["image", "seg"])] + + +@unittest.skipUnless(has_nib, "nibabel required") +class TestWriteFileMappingd(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.output_dir = os.path.join(self.temp_dir, "output") + os.makedirs(self.output_dir) + self.mapping_file_path = os.path.join(self.temp_dir, "mapping.json") + + def tearDown(self): + shutil.rmtree(self.temp_dir) + if os.path.exists(self.mapping_file_path): + os.remove(self.mapping_file_path) + + def run_test(self, save_keys, write_keys): + name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz") + data = [{"image": input_file}] + + test_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])]) + + post_transforms = Compose( + [ + SaveImaged( + keys=save_keys, + meta_keys="image_meta_dict", + output_dir=self.output_dir, + output_postfix="seg", + savepath_in_metadict=True, + ), + WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path), + ] + ) + + dataset = Dataset(data=data, transform=test_transforms) + dataloader = DataLoader(dataset, batch_size=1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device) + model.eval() + + with torch.no_grad(): + for batch_data in dataloader: + test_inputs = batch_data["image"].to(device) + roi_size = (64, 64, 64) + sw_batch_size = 2 + batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) + batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] + + return input_file, output_file + + @parameterized.expand(SUCCESS_CASES) + def test_successful_mapping_filed(self, save_keys, write_keys): + input_file, output_file = self.run_test(save_keys, write_keys) + self.assertTrue(os.path.exists(self.mapping_file_path)) + with open(self.mapping_file_path) as f: + mapping_data = json.load(f) + self.assertEqual(len(mapping_data), len(write_keys)) + for entry in mapping_data: + self.assertEqual(entry["input"], input_file) + self.assertEqual(entry["output"], output_file) + + @parameterized.expand(FAILURE_CASES) + def test_failure_mapping_filed(self, save_keys, write_keys): + with self.assertRaises(RuntimeError) as cm: + self.run_test(save_keys, write_keys) + + cause_exception = cm.exception.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.",
            str(cause_exception),
        )


if __name__ == "__main__":
    unittest.main()

From 29ce1a743bc067c259ac6646ec67c111a84ee80a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 29 Aug 2024 11:52:51 +0800
Subject: [PATCH 03/15] Use torch_tensorrt.Device instead of torch.device in trt compile (#8051)

Fixes #8050

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/networks/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index f301c2dd5c..bd65ffa33e 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -851,7 +851,7 @@ def _onnx_trt_compile(
     # wrap the serialized TensorRT engine back to a TorchScript module.
     trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
         f.getvalue(),
-        device=torch.device(f"cuda:{device}"),
+        device=torch_tensorrt.Device(f"cuda:{device}"),
         input_binding_names=input_names,
         output_binding_names=output_names,
     )

From b209347c0b804d966a83141d602777da5e27f1b7 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Sat, 31 Aug 2024 00:45:29 +0800
Subject: [PATCH 04/15] Ensure synchronization by adding cuda.synchronize() (#8058)

Fixes #8054

### Description

Add a CUDA synchronization call after invoking the CUDA kernels.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
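To make the synchronization pattern of this patch concrete, here is a minimal sketch of the guard it adds around the compiled CUDA filtering ops (`_C.some_cuda_op` is a hypothetical stand-in for the real bindings):

```python
import torch

def run_cuda_op(_C, input_tensor):
    # launch the (hypothetical) asynchronous CUDA kernel
    output = _C.some_cuda_op(input_tensor)
    # block the host until the kernel completes, so callers never
    # observe partially written results (the fix applied in #8058)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return output
```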
--------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/layers/filtering.py | 8 ++++++-- monai/transforms/utils.py | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 0ff1187dcc..c48c77cf98 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -51,6 +51,8 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): ctx.cs = color_sigma ctx.fa = fast_approx output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx) + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_data @staticmethod @@ -139,7 +141,8 @@ def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma): do_dsig_y, do_dsig_z, ) - + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_tensor @staticmethod @@ -301,7 +304,8 @@ def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma do_dsig_z, guidance_img, ) - + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_tensor @staticmethod diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 7027c07d67..1d1f070568 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2512,6 +2512,7 @@ def distance_transform_edt( block_params=block_params, float64_distances=float64_distances, ) + torch.cuda.synchronize() else: if not has_ndimage: raise RuntimeError("scipy.ndimage required if cupy is not available") @@ -2545,7 +2546,7 @@ def distance_transform_edt( r_vals = [] if return_distances and distances_original is None: - r_vals.append(distances) + r_vals.append(distances_ if use_cp else distances) if return_indices and indices_original is None: r_vals.append(indices) if not r_vals: From d0ba8a60950e4a7ffbffa3bab7349117123f1ddb Mon Sep 17 00:00:00 2001 From: Kennett Vera Date: Fri, 30 Aug 2024 17:07:06 -0700 Subject: [PATCH 05/15] Added docstring to address 'Scaling of RandImageFilterd transform #6857' (#8055) Fixes #6857 ### Description Added docstring to RandImageFilterd method which informs the user that they need to manually scale the result image when using this method. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] In-line docstrings updated. --------- Signed-off-by: dedeepyasai Signed-off-by: saelra Signed-off-by: Kelvin R Signed-off-by: ken-ni Signed-off-by: Dureti <98233210+DuretiShemsi@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: saelra Co-authored-by: Kelvin R. 
<138339140+K-Rilla@users.noreply.github.com> Co-authored-by: Kelvin R Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: dedeepyasai Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ratanachat Saelee <146144408+Saelra@users.noreply.github.com> Co-authored-by: Dureti <98233210+DuretiShemsi@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7e3a7b0454..2475060f4e 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1714,6 +1714,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform): Probability the transform is applied to the data allow_missing_keys: Don't raise exception if key is missing. + + Note: + - This transform does not scale output image values automatically to match the range of the input. + The output should be scaled by later transforms to match the input if this is desired. """ backend = ImageFilter.backend From fa1ef8be157d5eb96de17aa78642384f68d99397 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:25:42 +0800 Subject: [PATCH 06/15] Update base image to 2408 (#8049) Fixes #8048 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
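Returning to patch 05 for a moment: because `RandImageFilterd` does not rescale its output, an explicit intensity scaling can be appended in the pipeline. A hedged sketch (the kernel, kernel size, and scaling range are illustrative choices, not part of the patch):

```python
from monai.transforms import Compose, RandImageFilterd, ScaleIntensityd

xform = Compose(
    [
        # randomly apply a sharpening filter; the output range is not normalized
        RandImageFilterd(keys="image", kernel="sharpen", kernel_size=3, prob=1.0),
        # rescale back to [0, 1] to match the input convention
        ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
    ]
)
```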
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/cron.yml | 24 ++++++++++++------------ .github/workflows/pythonapp-gpu.yml | 4 ++-- Dockerfile | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index cc113b0446..6732ab7256 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -13,24 +13,24 @@ jobs: strategy: matrix: environment: - - "PT191+CUDA113" - "PT110+CUDA113" - - "PT113+CUDA113" - - "PTLATEST+CUDA121" + - "PT113+CUDA118" + - "PT210+CUDA121" + - "PTLATEST+CUDA124" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT110+CUDA113 pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PT113+CUDA113 - pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113" - base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PT113+CUDA122 + - environment: PT113+CUDA118 pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.2 + base: "nvcr.io/nvidia/pytorch:22.10-py3" # CUDA 11.8 + - environment: PT210+CUDA121 + pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121" + base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1 - environment: PTLATEST+CUDA124 pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:24.03-py3" # CUDA 12.4 + base: "nvcr.io/nvidia/pytorch:24.08-py3" # CUDA 12.4 container: image: ${{ matrix.base }} options: "--gpus all" @@ -80,7 +80,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:23.08", "pytorch:24.03"] + container: ["pytorch:23.08", "pytorch:24.08"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -129,7 +129,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:24.03"] + container: ["pytorch:24.08"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -233,7 +233,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:24.03-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:24.08-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index ead622b39c..70c3153076 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -44,9 +44,9 @@ jobs: pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error base: "nvcr.io/nvidia/pytorch:23.08-py3" - environment: PT210+CUDA121DOCKER - # 24.03: 2.3.0a0+40ec155e58.nv24.3 + # 24.08: 2.3.0a0+40ec155e58.nv24.3 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:24.03-py3" + base: "nvcr.io/nvidia/pytorch:24.08-py3" container: image: ${{ matrix.base }} options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for 
unsatisfied condition: cuda>=11.6

diff --git a/Dockerfile b/Dockerfile
index 8e255597d1..e97836e3ce 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,7 +11,7 @@
 # To build with a different base image
 # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.

-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3

 FROM ${PYTORCH_IMAGE}
 LABEL maintainer="monai.contact@gmail.com"

From c9f8d328fc0196ef166007a81dffc20a321f30af Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Sun, 1 Sep 2024 07:04:41 -0700
Subject: [PATCH 07/15] Added TRTWrapper (#7990)

### Description

Added an alternative class for ONNX->TRT export and for wrapping TRT engines for inference. It encapsulates filesystem persistence and does not rely on torch-tensorrt for execution. It can also be used to run ONNX with onnxruntime.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Boris Fomitchev
Signed-off-by: Yiheng Wang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: Yiheng Wang
Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com>
---
 Dockerfile                        |   1 +
 docs/source/config_syntax.md      |  42 +++
 monai/bundle/config_parser.py     |   8 +-
 monai/bundle/scripts.py           |   4 +-
 monai/bundle/utils.py             |  36 +-
 monai/handlers/__init__.py        |   1 +
 monai/handlers/trt_handler.py     |  61 ++++
 monai/networks/__init__.py        |   2 +
 monai/networks/nets/swin_unetr.py |   8 +-
 monai/networks/trt_compiler.py    | 565 ++++++++++++++++++++++++++++++
 monai/networks/utils.py           | 264 +++++++++++---
 requirements-dev.txt              |   2 +
 setup.cfg                         |   3 +
 tests/min_tests.py                |   1 +
 tests/test_config_parser.py       |  32 ++
 tests/test_sure_loss.py           |   2 +-
 tests/test_trt_compile.py         | 140 ++++++++
 17 files changed, 1121 insertions(+), 51 deletions(-)
 create mode 100644 monai/handlers/trt_handler.py
 create mode 100644 monai/networks/trt_compiler.py
 create mode 100644 tests/test_trt_compile.py

diff --git a/Dockerfile b/Dockerfile
index e97836e3ce..e45932c6bb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -56,4 +56,5 @@ RUN apt-get update \
   && rm -rf /var/lib/apt/lists/*
 # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
 ENV PATH=${PATH}:/opt/tools
+ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
 WORKDIR /opt/monai

diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md
index c932879b5a..742841acca 100644
--- a/docs/source/config_syntax.md
+++ b/docs/source/config_syntax.md
@@ -16,6 +16,7 @@ Content:
   - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
   - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
   - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
+  - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
  - [The command line interface](#the-command-line-interface)
  - [Recommendations](#recommendations)

@@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
   - `"debug"` -- execute with debug prompt and return the return value of
   ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

+## Multiple config files
+
+_Description:_ Multiple config files may be specified on the command line.
+The content of those config files is merged. When the same key is specified in more than one config file,
+the value associated with the key is overridden, in the order the config files are specified.
+If the desired behaviour is to merge values from both files, the key in the second config file should be prefixed with `+`.
+The value types of the merged contents must match, and be either both `dict` or both `list`:
+`dict` values will be merged via update(), `list` values concatenated via extend().
+Here is an example. In this case, the "amp" value will be overridden by extra_config.json, while the
+`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`:
+
+config.json:
+```json
+{
+    "amp": "$True",
+    "imports": [
+        "$import torch"
+    ],
+    "preprocessing": {
+        "_target_": "Compose",
+        "transforms": [
+            "$@t1",
+            "$@t2"
+        ]
+    }
+}
+```
+
+extra_config.json:
+```json
+{
+    "amp": "$False",
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "+preprocessing#transforms": [
+        "$@t3"
+    ]
+}
+```
+
 ## The command line interface

 In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.

diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py
index a2ffeedc92..1d9920a230 100644
--- a/monai/bundle/config_parser.py
+++ b/monai/bundle/config_parser.py
@@ -20,7 +20,7 @@
 from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
 from monai.bundle.reference_resolver import ReferenceResolver
-from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
+from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
 from monai.config import PathLike
 from monai.utils import ensure_tuple, look_up_option, optional_import
 from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
@@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
         if isinstance(files, str) and not Path(files).is_file() and "," in files:
             files = files.split(",")
         for i in ensure_tuple(files):
-            for k, v in (cls.load_config_file(i, **kwargs)).items():
-                parser[k] = v
+            config_dict = cls.load_config_file(i, **kwargs)
+            for k, v in config_dict.items():
+                merge_kv(parser, k, v)
+
         return parser.get()  # type: ignore

     @classmethod

diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 142a366669..f1d1286e4b 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -32,7 +32,7 @@
 from monai.apps.utils import _basename, download_url, extractall, get_logger
 from monai.bundle.config_item import ConfigComponent
 from monai.bundle.config_parser import ConfigParser
-from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
+from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
 from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
 from monai.config import IgniteInfo, PathLike
 from monai.data import load_net_with_metadata, save_net_with_metadata
@@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
         if isinstance(v, dict) and isinstance(args_.get(k), dict):
             args_[k] = update_kwargs(args_[k], ignore_none, **v)
         else:
-            args_[k] = v
+            merge_kv(args_, k, v)
     return args_

diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py
index 50d2608f4c..53d619f234 100644
--- a/monai/bundle/utils.py
+++ b/monai/bundle/utils.py
@@ -13,6 +13,7 @@
 import json
 import os
+import warnings
 import zipfile
 from typing import Any

@@ -21,12 +22,21 @@
 yaml, _ = optional_import("yaml")

-__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
+__all__ = [
+    "ID_REF_KEY",
+    "ID_SEP_KEY",
+    "EXPR_KEY",
+    "MACRO_KEY",
+    "MERGE_KEY",
+    "DEFAULT_MLFLOW_SETTINGS",
+    "DEFAULT_EXP_MGMT_SETTINGS",
+]

 ID_REF_KEY = "@"  # start of a reference to a ConfigItem
 ID_SEP_KEY = "::"  # separator for the ID of a ConfigItem
 EXPR_KEY = "$"  # start of a ConfigExpression
 MACRO_KEY = "%"  # start of a macro of a config
+MERGE_KEY = "+"  # prefix indicating merge instead of override in case of multiple configs.

 _conf_values = get_config_values()

@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
             parser.read_config(f=cdata)

     return parser
+
+
+def merge_kv(args: dict | Any, k: str, v: Any) -> None:
+    """
+    Update the `args` dict-like object with the key/value pair `k` and `v`.
+    """
+    if k.startswith(MERGE_KEY):
+        """
+        Both values associated with `+`-prefixed key pair must be of `dict` or `list` type.
+        `dict` values will be merged, `list` values - concatenated.
+        """
+        id = k[1:]
+        if id in args:
+            if isinstance(v, dict) and isinstance(args[id], dict):
+                args[id].update(v)
+            elif isinstance(v, list) and isinstance(args[id], list):
+                args[id].extend(v)
+            else:
+                raise ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")
+        else:
+            warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
+            args[id] = v
+    else:
+        args[k] = v

diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index 641f9aae7d..fa6e158be8 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -40,5 +40,6 @@
 from .stats_handler import StatsHandler
 from .surface_distance import SurfaceDistance
 from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
+from .trt_handler import TrtHandler
 from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
 from .validation_handler import ValidationHandler

diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py
new file mode 100644
index 0000000000..0e36b59d8c
--- /dev/null
+++ b/monai/handlers/trt_handler.py
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from monai.config import IgniteInfo +from monai.networks import trt_compile +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class TrtHandler: + """ + TrtHandler acts as an Ignite handler to apply TRT acceleration to the model. + Usage example:: + handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"}) + handler.attach(engine) + engine.run() + """ + + def __init__(self, model, base_path, args=None, submodule=None): + """ + Args: + base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan" + args: passed to trt_compile(). See trt_compile() for details. + submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder' + """ + self.model = model + self.base_path = base_path + self.args = args + self.submodule = submodule + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + self.logger = engine.logger + engine.add_event_handler(Events.STARTED, self) + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 4c429ae813..5a240021d6 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,7 +11,9 @@ from __future__ import annotations +from .trt_compiler import trt_compile from .utils import ( + add_casts_around_norms, convert_to_onnx, convert_to_torchscript, convert_to_trt, diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3900c866b3..714d986f4b 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape): ) def forward(self, x_in): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) @@ -1046,14 +1046,14 @@ def __init__( def proj_out(self, x, normalize=False): if normalize: - x_shape = x.size() + x_shape = x.shape + # Force trace() to generate a constant by casting to int + ch = int(x_shape[1]) if len(x_shape) == 5: - n, ch, d, h, w = x_shape x = rearrange(x, "n c d h w -> n d h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n d h w c -> n c d h w") elif len(x_shape) == 4: - n, ch, h, w = x_shape x = rearrange(x, "n c h w -> n h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n h w c -> n c h w") diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py new file mode 100644 index 0000000000..a9dd0d9e9b --- /dev/null +++ b/monai/networks/trt_compiler.py @@ -0,0 +1,565 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import os +import tempfile +import threading +from collections import OrderedDict +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Union + +import torch + +from monai.apps.utils import get_logger +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.utils.module import optional_import + +polygraphy, polygraphy_imported = optional_import("polygraphy") +if polygraphy_imported: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_bytes_from_network, + engine_from_bytes, + network_from_onnx_path, + ) + +trt, trt_imported = optional_import("tensorrt") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") +cudart, _ = optional_import("cuda.cudart") + + +lock_sm = threading.Lock() + + +# Map of TRT dtype -> Torch dtype +def trt_to_torch_dtype_dict(): + return { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def get_dynamic_axes(profiles): + """ + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions + """ + dynamic_axes: dict[str, list[int]] = {} + if not profiles: + return dynamic_axes + for profile in profiles: + for key in profile: + axes = [] + vals = profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + + +def cuassert(cuda_ret): + """ + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. + """ + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class ShapeError(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ + + pass + + +class TRTEngine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + + def __init__(self, plan_path, logger=None): + """ + Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. 
+ logger: optional logger object + """ + self.plan_path = plan_path + self.logger = logger or get_logger("trt_compile") + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + self.cur_profile = 0 + dtype_dict = trt_to_torch_dtype_dict() + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + + def allocate_buffers(self, device): + """ + Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on + """ + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use + """ + e = self.engine + ctx = self.context + + last_profile = self.cur_profile + + def try_set_inputs(): + for binding, t in feed_dict.items(): + if t is not None: + t = t.contiguous() + shape = t.shape + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeError: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. 
+ """ + if use_cuda_graph: + if self.cuda_graph_instance is not None: + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + cuassert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal) + ) + self.context.execute_async_v3(stream) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) + self.logger.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + cuassert(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TrtCompiler: + """ + This class implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) + """ + + def __init__( + self, + model, + plan_path, + precision="fp16", + method="onnx", + input_names=None, + output_names=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + logger=None, + ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves its arguments for lazy TRT build on first forward() call + Args: + model: Model to "wrap". + plan_path : Path where to save persistent serialized TRT engine. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. + method: One of 'onnx'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. + dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine. + use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). 
+ """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + + self.plan_path = plan_path + self.precision = precision + self.method = method + self.return_dict = output_names is not None + self.output_names = output_names or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None + self.use_cuda_graph = use_cuda_graph + self.fallback = fallback + self.disabled = False + + self.logger = logger or get_logger("trt_compile") + + # Normally we read input_names from forward() but can be overridden + if input_names is None: + argspec = inspect.getfullargspec(model.forward) + input_names = argspec.args[1:] + self.input_names = input_names + self.old_forward = model.forward + + # Force engine rebuild if older than the timestamp + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: + os.remove(self.plan_path) + + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def _load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ + try: + self.engine = TRTEngine(self.plan_path, self.logger) + self.input_names = self.engine.input_names + except Exception as e: + self.logger.debug(f"Exception while loading the engine:\n{e}") + + def forward(self, model, argv, kwargs): + """ + Main forward method: + Builds TRT engine if not available yet. 
+ Tries to run the TRT engine.
+ If an exception is thrown and self.fallback==True, falls back to the original PyTorch forward().
+
+ Args: passes through whatever args the wrapped module's forward() has.
+ Returns: passes through the wrapped module's forward() return value(s).
+
+ """
+ if self.engine is None and not self.disabled:
+ # Restore original forward for export
+ new_forward = model.forward
+ model.forward = self.old_forward
+ try:
+ self._load_engine()
+ if self.engine is None:
+ build_args = kwargs.copy()
+ if len(argv) > 0:
+ build_args.update(self._inputs_to_dict(argv))
+ self._build_and_save(model, build_args)
+ # This will reassign input_names from the engine
+ self._load_engine()
+ except Exception as e:
+ if self.fallback:
+ self.logger.info(f"Failed to build engine: {e}")
+ self.disabled = True
+ else:
+ raise e
+ if not self.disabled and not self.fallback:
+ # Delete all parameters
+ for param in model.parameters():
+ del param
+ # Call empty_cache to release GPU memory
+ torch.cuda.empty_cache()
+ model.forward = new_forward
+ # Run the engine
+ try:
+ if len(argv) > 0:
+ kwargs.update(self._inputs_to_dict(argv))
+ argv = ()
+
+ if self.engine is not None:
+ # forward_trt is not thread safe as we do not use per-thread execution contexts
+ with lock_sm:
+ device = torch.cuda.current_device()
+ stream = torch.cuda.Stream(device=device)
+ self.engine.set_inputs(kwargs, stream.cuda_stream)
+ self.engine.allocate_buffers(device=device)
+ # Need this to synchronize with Torch stream
+ stream.wait_stream(torch.cuda.current_stream())
+ ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
+ # if output_names is not None, return dictionary
+ if not self.return_dict:
+ ret = list(ret.values())
+ if len(ret) == 1:
+ ret = ret[0]
+ return ret
+ except Exception as e:
+ if model is not None:
+ self.logger.info(f"Exception: {e}\nFalling back to PyTorch ...")
+ else:
+ raise e
+ return self.old_forward(*argv, **kwargs)
+
+ def _onnx_to_trt(self, onnx_path):
+ """
+ Builds a TRT engine from the ONNX file at onnx_path and returns the serialized engine bytes.
+ """
+
+ profiles = []
+ if self.profiles:
+ for input_profile in self.profiles:
+ if isinstance(input_profile, Profile):
+ profiles.append(input_profile)
+ else:
+ p = Profile()
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+ profiles.append(p)
+
+ build_args = self.build_args.copy()
+ build_args["tf32"] = self.precision != "fp32"
+ build_args["fp16"] = self.precision == "fp16"
+ build_args["bf16"] = self.precision == "bf16"
+
+ self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
+ network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+ return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))
+
+ def _build_and_save(self, model, input_example):
+ """
+ If TRT engine is not ready, exports model to ONNX,
+ builds the TRT engine and saves the serialized TRT engine to disk.
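For orientation, the polygraphy calls in `_onnx_to_trt()` above compose roughly as follows (a sketch using the same `polygraphy.backend.trt` helpers this module relies on; the ONNX path and input name are hypothetical):

```python
import tensorrt as trt
from polygraphy.backend.trt import CreateConfig, Profile, engine_bytes_from_network, network_from_onnx_path

# One optimization profile: a min/opt/max shape range per input tensor.
profile = Profile()
profile.add("x", min=[1, 1, 96, 96, 96], opt=[4, 1, 96, 96, 96], max=[8, 1, 96, 96, 96])

# Parse the ONNX file and build a serialized fp16 engine, mirroring _onnx_to_trt().
network = network_from_onnx_path("model.onnx", flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
engine_bytes = engine_bytes_from_network(network, config=CreateConfig(profiles=[profile], fp16=True))
```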
+ Args:
+ input_example: passed to onnx.export()
+ """
+
+ if self.engine is not None:
+ return
+
+ export_args = self.export_args
+
+ add_casts_around_norms(model)
+
+ if self.method == "torch_trt":
+ enabled_precisions = [torch.float32]
+ if self.precision == "fp16":
+ enabled_precisions.append(torch.float16)
+ elif self.precision == "bf16":
+ enabled_precisions.append(torch.bfloat16)
+ inputs = list(input_example.values())
+ ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True)
+
+ def get_torch_trt_input(input_shape, dynamic_batchsize):
+ min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
+ return torch_tensorrt.Input(
+ min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
+ )
+
+ tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
+ engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
+ ir_model,
+ "forward",
+ inputs=tt_inputs,
+ ir="torchscript",
+ enabled_precisions=enabled_precisions,
+ **export_args,
+ )
+ else:
+ dbs = self.dynamic_batchsize
+ if dbs:
+ if len(self.profiles) > 0:
+ raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
+ if len(dbs) != 3:
+ raise ValueError("dynamic_batchsize has to have length 3.")
+ profiles = {}
+ for id, val in input_example.items():
+ sh = val.shape[1:]
+ profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
+ self.profiles = [profiles]
+
+ if len(self.profiles) > 0:
+ export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)})
+
+ # Use temporary directory for easy cleanup in case of external weights
+ with tempfile.TemporaryDirectory() as tmpdir:
+ onnx_path = Path(tmpdir) / "model.onnx"
+ self.logger.info(
+ f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}"
+ )
+ convert_to_onnx(
+ model,
+ input_example,
+ filename=str(onnx_path),
+ input_names=self.input_names,
+ output_names=self.output_names,
+ **export_args,
+ )
+ self.logger.info("Export to ONNX successful.")
+ engine_bytes = self._onnx_to_trt(str(onnx_path))
+
+ Path(self.plan_path).write_bytes(engine_bytes)
+
+
+def trt_forward(self, *argv, **kwargs):
+ """
+ Patch function that replaces the original model's forward();
+ redirects to TrtCompiler.forward().
+ """
+ return self._trt_compiler.forward(self, argv, kwargs)
+
+
+def trt_compile(
+ model: torch.nn.Module,
+ base_path: str,
+ args: Dict[str, Any] | None = None,
+ submodule: Union[str, List[str]] | None = None,
+ logger: Any | None = None,
+) -> torch.nn.Module:
+ """
+ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with a TRT hook.
+ Args:
+ model: module to patch with TrtCompiler object.
+ base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
+ dirname(base_path) must exist, base_path does not have to.
+ If base_path points to an existing file (e.g. an associated checkpoint),
+ that file becomes a dependency - its mtime is added to args["timestamp"].
+ args: Optional dict: unpacked and passed to TrtCompiler() - see TrtCompiler above for details.
+ submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']
+ If None, TrtCompiler patch is applied to the whole model.
+ Otherwise, the given submodule (or list of submodules) is patched.
+ logger: Optional logger for diagnostics.
+ Returns:
+ Always returns the same model passed in as argument. This is for ease of use in configs.
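A minimal usage sketch for `trt_compile()` as documented above (the plan path is hypothetical; the UNet configuration mirrors the unit test added later in this patch):

```python
import torch
from monai.networks import trt_compile
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2, channels=(2, 2, 4, 8, 4), strides=(2, 2, 2, 2)
).cuda().eval()

# Lazily builds the engine on the first forward() call and caches it at "/models/unet_v1.plan".
trt_compile(model, "/models/unet_v1", args={"precision": "fp16", "dynamic_batchsize": [1, 4, 8]})

with torch.no_grad():
    out = model(torch.randn(2, 1, 96, 96, 96, device="cuda"))  # first call triggers the TRT build
```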
+ """ + + default_args: Dict[str, Any] = { + "method": "onnx", + "precision": "fp16", + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, + } + + default_args.update(args or {}) + args = default_args + + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) + if "timestamp" in args: + timestamp = max(int(args["timestamp"]), timestamp) + args["timestamp"] = timestamp + + def wrap(model, path): + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) + else: + wrap(model, base_path) + else: + logger = logger or get_logger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.") + + return model diff --git a/monai/networks/utils.py b/monai/networks/utils.py index bd65ffa33e..d0150b4e5b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,6 +36,8 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") +polygraphy, polygraphy_imported = optional_import("polygraphy") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") __all__ = [ "one_hot", @@ -61,6 +63,7 @@ "look_up_named_module", "set_named_module", "has_nvfuser_instance_norm", + "get_profile_shapes", ] logger = get_logger(module_name=__name__) @@ -68,6 +71,26 @@ _has_nvfuser = None +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + def has_nvfuser_instance_norm(): """whether the current environment has InstanceNorm3dNVFuser https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16 @@ -606,6 +629,9 @@ def convert_to_onnx( rtol: float = 1e-4, atol: float = 0.0, use_trace: bool = True, + do_constant_folding: bool = True, + constant_size_threshold: int = 16 * 1024 * 1024 * 1024, + dynamo=False, **kwargs, ): """ @@ -632,7 +658,10 @@ def convert_to_onnx( rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. 
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
 use_trace: whether to use `torch.jit.trace` to export the torchscript model.
- kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+ do_constant_folding: passed to torch.onnx.export(). If True, an extra polygraphy constant-folding pass is also run.
+ constant_size_threshold: size threshold passed to the polygraphy constant-folding pass; defaults to ``16 * 1024 * 1024 * 1024``.
+ kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()
+ else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
 https://pytorch.org/docs/master/generated/torch.jit.script.html.
 """
@@ -642,6 +671,7 @@ def convert_to_onnx(
 if use_trace:
 # let torch.onnx.export to trace the model.
 mode_to_export = model
+ torch_versioned_kwargs = kwargs
 else:
 if not pytorch_after(1, 10):
 if "example_outputs" not in kwargs:
 del kwargs["example_outputs"]
 mode_to_export = torch.jit.script(model, **kwargs)
 
+ if torch.is_tensor(inputs) or isinstance(inputs, dict):
+ onnx_inputs = (inputs,)
+ else:
+ onnx_inputs = tuple(inputs)
+
 if filename is None:
 f = io.BytesIO()
- torch.onnx.export(
- mode_to_export,
- tuple(inputs),
- f=f,
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic_axes,
- opset_version=opset_version,
- **torch_versioned_kwargs,
- )
+ else:
+ f = filename
+
+ torch.onnx.export(
+ mode_to_export,
+ onnx_inputs,
+ f=f,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ opset_version=opset_version,
+ do_constant_folding=do_constant_folding,
+ **torch_versioned_kwargs,
+ )
+ if filename is None:
 onnx_model = onnx.load_model_from_string(f.getvalue())
 else:
- torch.onnx.export(
- mode_to_export,
- tuple(inputs),
- f=filename,
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic_axes,
- opset_version=opset_version,
- **torch_versioned_kwargs,
- )
 onnx_model = onnx.load(filename)
+ if do_constant_folding and polygraphy_imported:
+ from polygraphy.backend.onnx.loader import fold_constants
+
+ fold_constants(onnx_model, size_threshold=constant_size_threshold)
+
 if verify:
 if device is None:
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -814,7 +849,6 @@ def _onnx_trt_compile(
 """
 trt, _ = optional_import("tensorrt", "8.5.3")
- torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
 input_shapes = (min_shape, opt_shape, max_shape)
 # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -916,8 +950,6 @@ def convert_to_trt(
 to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
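To tie the two additions together, here is a hedged sketch of the new `get_profile_shapes()` helper and the extended `convert_to_onnx()` (assuming onnx, and optionally polygraphy, are installed; the toy model is illustrative only):

```python
import torch
from monai.networks.utils import convert_to_onnx, get_profile_shapes

# Expands a sample shape along the batch dimension:
# ([2, 1, 96, 96], [1, 4, 8]) -> ([1, 1, 96, 96], [4, 1, 96, 96], [8, 1, 96, 96])
print(get_profile_shapes([2, 1, 96, 96], [1, 4, 8]))

model = torch.nn.Conv2d(1, 4, 3)
onnx_model = convert_to_onnx(
    model,
    inputs=[torch.randn(1, 1, 96, 96)],
    input_names=["x"],
    output_names=["y"],
    use_trace=True,
    do_constant_folding=True,  # also runs the extra polygraphy folding pass when polygraphy is importable
)
```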
""" - torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0") - if not torch.cuda.is_available(): raise Exception("Cannot find any GPU devices.") @@ -935,23 +967,9 @@ def convert_to_trt( convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] - def scale_batch_size(input_shape: Sequence[int], scale_num: int): - scale_shape = [*input_shape] - scale_shape[0] *= scale_num - return scale_shape - - # Use the dynamic batchsize range to generate the min, opt and max model input shape - if dynamic_batchsize: - min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) - opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) - max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) - else: - min_input_shape = opt_input_shape = max_input_shape = input_shape - # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) - ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) - ir_model.eval() + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) if use_onnx: # set the batch dim as dynamic @@ -960,7 +978,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model = convert_to_onnx( model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes ) - # convert the model through the ONNX-TensorRT way trt_model = _onnx_trt_compile( ir_model, @@ -973,6 +990,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): output_names=onnx_output_names, ) else: + ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) + ir_model.eval() # convert the model through the Torch-TensorRT way ir_model.to(target_device) with torch.no_grad(): @@ -1189,3 +1208,168 @@ def forward(self, x): if dtype == self.initial_type: x = x.to(self.initial_type) return x + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + dtype = x.dtype + with torch.amp.autocast("cuda", enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.amp.autocast("cuda", enabled=False): + ret = 
self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) + + +def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t wrapper. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + out = dest_t(mod) + return out + + return expansion_fn + + +def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t. + base_t and dest_t should have same atrributes. No weights are copied. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + if not isinstance(mod, base_t): + return None + args = [getattr(mod, name, None) for name in mod.__constants__] + out = dest_t(*args) + return out + + return expansion_fn + + +def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: + """ + This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows + for swapping nested modules through arbitrary levels if children + + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + + """ + for path, new_mod in mapping.items(): + expanded_path = path.split(".") + parent_mod = model + for sub_path in expanded_path[:-1]: + submod = parent_mod._modules[sub_path] + if submod is None: + break + else: + parent_mod = submod + parent_mod._modules[expanded_path[-1]] = new_mod + + return model + + +def replace_modules_by_type( + model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]] +) -> nn.Module: + """ + Top-level function to replace modules in model, specified by class name with a desired replacement. + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + expansions : replacement dictionary: module class name -> replacement function generator + Returns: + model, possibly modified in-place + """ + mapping: dict[str, nn.Module] = {} + for name, m in model.named_modules(): + m_type = type(m).__name__ + if m_type in expansions: + # print (f"Found {m_type} in expansions ...") + swapped = expansions[m_type](m) + if swapped: + mapping[name] = swapped + + print(f"Swapped {len(mapping)} modules") + _swap_modules(model, mapping) + return model + + +def add_casts_around_norms(model: nn.Module) -> nn.Module: + """ + Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. 
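To make the swapping machinery concrete, here is a minimal sketch using `wrap_module()` together with `replace_modules_by_type()` (the same pattern `add_casts_around_norms()` below applies to the norm layers):

```python
import torch.nn as nn
from monai.networks.utils import CastToFloat, replace_modules_by_type, wrap_module

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# Wrap every module whose class name is "BatchNorm2d" in a CastToFloat shim (in place).
replace_modules_by_type(model, {"BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat)})
assert isinstance(model[1], CastToFloat)
```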
+ Args:
+ model : top level module
+ Returns:
+ model, possibly modified in-place
+ """
+ print("Adding casts around norms...")
+ cast_replacements = {
+ "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+ "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+ "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+ "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+ "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+ "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+ }
+ replace_modules_by_type(model, cast_replacements)
+ return model
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9aad0804e6..6d0ccd378a 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -59,3 +59,5 @@ nvidia-ml-py
 huggingface_hub
 pyamg>=5.0.0
 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
+onnx_graphsurgeon
+polygraphy
diff --git a/setup.cfg b/setup.cfg
index 1ce4a3f34c..c97118d43a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -160,6 +160,9 @@ lpips =
 lpips==0.1.4
 pynvml =
 nvidia-ml-py
+polygraphy =
+ polygraphy
+
 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882
 # MetricsReloaded =
 # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
diff --git a/tests/min_tests.py b/tests/min_tests.py
index f80d06f5d3..632355b5c6 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -186,6 +186,7 @@ def run_testsuit():
 "test_torchvisiond",
 "test_transchex",
 "test_transformerblock",
+ "test_trt_compile",
 "test_unetr",
 "test_unetr_block",
 "test_vit",
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index cf1edc8f08..2b00c9f9d1 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -125,6 +125,22 @@ def __call__(self, a, b):
 [0, 4],
 ]
 
+TEST_CASE_MERGE_JSON = ["""{"key1": [0], "key2": [0] }""", """{"key1": [1], "+key2": [4] }""", "json", [1], [0, 4]]
+
+TEST_CASE_MERGE_YAML = [
+ """
+ key1: 0
+ key2: [0]
+ """,
+ """
+ key1: 1
+ +key2: [4]
+ """,
+ "yaml",
+ 1,
+ [0, 4],
+]
+
 
 class TestConfigParser(unittest.TestCase):
 
@@ -357,6 +373,22 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex
 self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
 self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)
 
+ @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML])
+ @skipUnless(has_yaml, "Requires pyyaml")
+ def test_load_configs(
+ self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals
+ ):
+ with tempfile.TemporaryDirectory() as tempdir:
+ config_path1 = Path(tempdir) / f"config1.{extension}"
+ config_path2 = Path(tempdir) / f"config2.{extension}"
+ config_path1.write_text(config_string)
+ config_path2.write_text(config_string2)
+
+ parser = ConfigParser.load_config_files([config_path1, config_path2])
+
+ self.assertEqual(parser["key1"], expected_overridden_val)
+ self.assertEqual(parser["key2"], expected_merged_vals)
+
 
 if __name__ == "__main__":
 unittest.main()
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
index 903f9bd2ca..fb8f5dda72 100644
--- a/tests/test_sure_loss.py
+++ b/tests/test_sure_loss.py
@@ -65,7 +65,7 @@ def operator(x):
 loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
 loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
 
- self.assertAlmostEqual(loss_real.item(), 
loss_complex.abs().item(), places=6) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5) if __name__ == "__main__": diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py new file mode 100644 index 0000000000..21125d203f --- /dev/null +++ b/tests/test_trt_compile.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import TrtHandler +from monai.networks import trt_compile +from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows + +trt, trt_imported = optional_import("tensorrt") +polygraphy, polygraphy_imported = optional_import("polygraphy") +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +@skip_if_windows +@skip_if_no_cuda +@skip_if_quick +@unittest.skipUnless(trt_imported, "tensorrt is required") +@unittest.skipUnless(polygraphy_imported, "polygraphy is required") +class TestTRTCompile(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + def test_handler(self): + from ignite.engine import Engine + + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + net1.cuda() + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + args = {"method": "torch_trt"} + TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine) + engine.run([0] * 8, max_epochs=1) + self.assertIsNotNone(net1._trt_compiler) + self.assertIsNone(net1._trt_compiler.engine) + net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) + self.assertIsNotNone(net1._trt_compiler.engine) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_unet_value(self, precision): + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(2, 2, 4, 8, 4), + strides=(2, 2, 2, 2), + num_res_units=2, + norm="batch", + ).cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(2, 1, 96, 96, 96).cuda() + output_example = model(input_example) + args: dict = {"builder_optimization_level": 1} + trt_compile( + model, + f"{tmpdir}/test_unet_trt_compile", + args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + 
torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @unittest.skipUnless(has_sam, "Requires SAM installation") + def test_cell_sam_wrapper_value(self, precision): + model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 3, 128, 128).to("cuda") + output_example = model(input_example) + trt_compile( + model, + f"{tmpdir}/test_cell_sam_wrapper_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_vista3d(self, precision): + model = vista3d132(in_channels=1).to("cuda") + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(1, 1, 64, 64, 64).to("cuda") + output_example = model(input_example) + model = trt_compile( + model, + f"{tmpdir}/test_vista3d_trt_compile", + args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + submodule=["image_encoder.encoder", "class_head"], + ) + self.assertIsNotNone(model.image_encoder.encoder._trt_compiler) + self.assertIsNotNone(model.class_head._trt_compiler) + trt_output = model.forward(input_example) + # Check that lazy TRT build succeeded + # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called + # and uncomment the asserts below + # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine) + # self.assertIsNotNone(model.class_head._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + unittest.main() From 7219ee7db771930179d9f219c59463c2c6d227ef Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 2 Sep 2024 13:12:15 +0800 Subject: [PATCH 08/15] Add box and points convert transform (#8053) Add box and points convert transform Cherrypick ApplyTransformToPoints ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
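### Example
A quick sketch of the round trip these transforms enable (StandardMode `[x1, y1, x2, y2]` boxes assumed; shapes follow the new docstrings):

```python
import torch
from monai.transforms import ConvertBoxToPoints, ConvertPointsToBoxes

boxes = torch.tensor([[0.0, 0.0, 2.0, 3.0]])  # one 2D box
points = ConvertBoxToPoints()(boxes)          # shape (1, 4, 2): the 4 corners
recovered = ConvertPointsToBoxes()(points)    # shape (1, 4): back to [x1, y1, x2, y2]
assert torch.allclose(boxes, recovered)
```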
--------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- docs/source/transforms.rst | 36 ++++++ monai/transforms/__init__.py | 12 ++ monai/transforms/spatial/array.py | 44 +++++++ monai/transforms/spatial/dictionary.py | 60 ++++++++++ monai/transforms/spatial/functional.py | 71 +++++++++++- monai/transforms/utility/array.py | 140 ++++++++++++++++++++++- monai/transforms/utility/dictionary.py | 74 ++++++++++++ monai/transforms/utils.py | 23 ++++ monai/utils/__init__.py | 1 + monai/utils/type_conversion.py | 8 ++ tests/test_apply_transform_to_points.py | 81 +++++++++++++ tests/test_apply_transform_to_pointsd.py | 133 +++++++++++++++++++++ tests/test_convert_box_points.py | 121 ++++++++++++++++++++ 13 files changed, 799 insertions(+), 5 deletions(-) create mode 100644 tests/test_apply_transform_to_points.py create mode 100644 tests/test_apply_transform_to_pointsd.py create mode 100644 tests/test_convert_box_points.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3e45d899ec..41bb4ae79a 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -976,6 +976,18 @@ Spatial :members: :special-members: __call__ +`ConvertBoxToPoints` +"""""""""""""""""""" +.. autoclass:: ConvertBoxToPoints + :members: + :special-members: __call__ + +`ConvertPointsToBoxes` +"""""""""""""""""""""" +.. autoclass:: ConvertPointsToBoxes + :members: + :special-members: __call__ + Smooth Field ^^^^^^^^^^^^ @@ -1222,6 +1234,12 @@ Utility :members: :special-members: __call__ +`ApplyTransformToPoints` +"""""""""""""""""""""""" +.. autoclass:: ApplyTransformToPoints + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1973,6 +1991,18 @@ Spatial (Dict) :members: :special-members: __call__ +`ConvertBoxToPointsd` +""""""""""""""""""""" +.. autoclass:: ConvertBoxToPointsd + :members: + :special-members: __call__ + +`ConvertPointsToBoxesd` +""""""""""""""""""""""" +.. autoclass:: ConvertPointsToBoxesd + :members: + :special-members: __call__ + Smooth Field (Dict) ^^^^^^^^^^^^^^^^^^^ @@ -2277,6 +2307,12 @@ Utility (Dict) :members: :special-members: __call__ +`ApplyTransformToPointsd` +""""""""""""""""""""""""" +.. 
autoclass:: ApplyTransformToPointsd + :members: + :special-members: __call__ + MetaTensor ^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f37016e63f..2cdd965c91 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -396,6 +396,8 @@ from .spatial.array import ( Affine, AffineGrid, + ConvertBoxToPoints, + ConvertPointsToBoxes, Flip, GridDistortion, GridPatch, @@ -427,6 +429,12 @@ Affined, AffineD, AffineDict, + ConvertBoxToPointsd, + ConvertBoxToPointsD, + ConvertBoxToPointsDict, + ConvertPointsToBoxesd, + ConvertPointsToBoxesD, + ConvertPointsToBoxesDict, Flipd, FlipD, FlipDict, @@ -503,6 +511,7 @@ from .utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -542,6 +551,9 @@ AddExtremePointsChanneld, AddExtremePointsChannelD, AddExtremePointsChannelDict, + ApplyTransformToPointsd, + ApplyTransformToPointsD, + ApplyTransformToPointsDict, AsChannelLastd, AsChannelLastD, AsChannelLastDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3739a83e71..6e39fb2e19 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -25,6 +25,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import BoxMode, StandardMode from monai.data.meta_obj import get_track_meta, set_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine @@ -34,6 +35,8 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.functional import ( affine_func, + convert_box_to_points, + convert_points_to_box, flip, orientation, resize, @@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: else: return img + + +class ConvertBoxToPoints(Transform): + """ + Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode. + Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box. + Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None: + """ + Args: + mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode. + """ + super().__init__() + self.mode = StandardMode if mode is None else mode + + def __call__(self, data: Any): + data = convert_to_tensor(data, track_meta=get_track_meta()) + points = convert_box_to_points(data, mode=self.mode) + return convert_to_dst_type(points, data)[0] + + +class ConvertPointsToBoxes(Transform): + """ + Converts points to an axis-aligned bounding box. + Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of a 3D cuboid or + (N, 4, 2) for the 4 corners of a 2D rectangle. 
+ """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self) -> None: + super().__init__() + + def __call__(self, data: Any): + data = convert_to_tensor(data, track_meta=get_track_meta()) + box = convert_points_to_box(data) + return convert_to_dst_type(box, data)[0] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 01fadcfb69..82dee15c7c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike, KeysCollection, SequenceStr from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import BoxMode, StandardMode from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter @@ -33,6 +34,8 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, + ConvertBoxToPoints, + ConvertPointsToBoxes, Flip, GridDistortion, GridPatch, @@ -2611,6 +2614,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ConvertBoxToPointsd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`. + """ + + backend = ConvertBoxToPoints.backend + + def __init__( + self, + keys: KeysCollection, + point_key="points", + mode: str | BoxMode | type[BoxMode] | None = StandardMode, + allow_missing_keys: bool = False, + ): + """ + Args: + keys: keys of the corresponding items to be transformed. + point_key: key to store the point data. + mode: the mode of the input boxes. Defaults to StandardMode. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.point_key = point_key + self.converter = ConvertBoxToPoints(mode=mode) + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + data[self.point_key] = self.converter(d[key]) + return data + + +class ConvertPointsToBoxesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`. + """ + + def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False): + """ + Args: + keys: keys of the corresponding items to be transformed. + box_key: key to store the box data. + allow_missing_keys: don't raise exception if key is missing. 
+ """ + super().__init__(keys, allow_missing_keys) + self.box_key = box_key + self.converter = ConvertPointsToBoxes() + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + data[self.box_key] = self.converter(d[key]) + return data + + SpatialResampleD = SpatialResampleDict = SpatialResampled ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd @@ -2635,3 +2693,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N GridPatchD = GridPatchDict = GridPatchd RandGridPatchD = RandGridPatchDict = RandGridPatchd RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond +ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd +ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 22726f06a5..b693e7d023 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -24,6 +24,7 @@ import monai from monai.config import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import get_boxmode from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd @@ -32,7 +33,7 @@ from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import TraceableTransform from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine -from monai.transforms.utils_pytorch_numpy_unification import allclose +from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack from monai.utils import ( LazyAttr, TraceKeys, @@ -610,3 +611,71 @@ def affine_func( out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device) out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out return out if image_only else (out, affine) + + +def convert_box_to_points(bbox, mode): + """ + Converts an axis-aligned bounding box to points. + + Args: + mode: The mode specifying how to interpret the bounding box. + bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] + for 3D for each box. Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D. + + Returns: + sequence of points representing the corners of the bounding box. 
+ """ + + mode = get_boxmode(mode) + + points_list = [] + for _num in range(bbox.shape[0]): + corners = mode.boxes_to_corners(bbox[_num : _num + 1]) + if len(corners) == 4: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1]], axis=1), + concatenate([corners[2], corners[1]], axis=1), + concatenate([corners[2], corners[3]], axis=1), + concatenate([corners[0], corners[3]], axis=1), + ], + axis=0, + ) + ) + else: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[4], corners[5]], axis=1), + concatenate([corners[0], corners[4], corners[5]], axis=1), + ], + axis=0, + ) + ) + + return stack(points_list, dim=0) + + +def convert_points_to_box(points): + """ + Converts points to an axis-aligned bounding box. + + Args: + points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of + a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle. + """ + from monai.transforms.utils_pytorch_numpy_unification import max, min + + mins = min(points, dim=1) + maxs = max(points, dim=1) + # Concatenate the min and max values to get the bounding boxes + bboxes = concatenate([mins, maxs], axis=1) + + return bboxes diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5dfbcb0e91..fee546bea3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,7 +31,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import is_no_channel, no_collation +from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps from monai.networks.layers.simplelayers import ( ApplyFilter, EllipticalFilter, @@ -42,16 +42,17 @@ SharpenFilter, median_filter, ) -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( + apply_affine_to_points, extreme_points_to_image, get_extreme_points, map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices from monai.utils import ( MetaKeys, TraceKeys, @@ -66,7 +67,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype +from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -106,6 +107,7 @@ "ToCupy", "ImageFilter", "RandImageFilter", + "ApplyTransformToPoints", ] @@ -1715,3 +1717,133 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None 
= None) -> Nd if self._do_transform: img = self.filter(img) return img + + +class ApplyTransformToPoints(InvertibleTransform, Transform): + """ + Transform points between image coordinates and world coordinates. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels + and N denotes the number of points. It will return a tensor with the same shape as the input. + + Args: + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the input data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from an image and represents its location in world space, + while the points are in world coordinates. A value of ``True`` represents transforming these + world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation. + affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + + Use Cases: + - Transforming points between world space and image space, and vice versa. + - Automatically handling inverse transformations between image space and world space. + - If points have an existing affine transformation, the class computes and + applies the required delta affine transformation. + + """ + + def __init__( + self, + dtype: DtypeLike | torch.dtype | None = None, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + ) -> None: + self.dtype = dtype + self.affine = affine + self.invert_affine = invert_affine + self.affine_lps_to_ras = affine_lps_to_ras + + def transform_coordinates( + self, data: torch.Tensor, affine: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: + """ + Transform coordinates using an affine transformation matrix. + + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation, + which can be computationally expensive when applied to a large number of points. + + Returns: + Transformed coordinates. 
+ """ + data = convert_to_tensor(data, track_meta=get_track_meta()) + # applied_affine is the affine transformation matrix that has already been applied to the point data + applied_affine = getattr(data, "affine", None) + + if affine is None and self.invert_affine: + raise ValueError("affine must be provided when invert_affine is True.") + + affine = applied_affine if affine is None else affine + affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine + original_affine: torch.Tensor = affine + if self.affine_lps_to_ras: + affine = orientation_ras_lps(affine) + + # the final affine transformation matrix that will be applied to the point data + _affine: torch.Tensor = affine + if self.invert_affine: + _affine = linalg_inv(affine) + if applied_affine is not None: + # consider the affine transformation already applied to the data in the world space + # and compute delta affine + _affine = _affine @ linalg_inv(applied_affine) + out = apply_affine_to_points(data, _affine, dtype=self.dtype) + + extra_info = { + "invert_affine": self.invert_affine, + "dtype": get_dtype_string(self.dtype), + "image_affine": original_affine, # record for inverse operation + "affine_lps_to_ras": self.affine_lps_to_ras, + } + xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine) + meta_info = TraceableTransform.track_transform_meta( + data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info() + ) + + return out, meta_info + + def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None): + """ + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``. 
+ """ + if data.ndim != 3 or data.shape[-1] not in (2, 3): + raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.") + affine = self.affine if affine is None else affine + if affine is not None and affine.shape not in ((3, 3), (4, 4)): + raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.") + + out, meta_info = self.transform_coordinates(data, affine) + + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + # Create inverse transform + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] + invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"] + affine = transform[TraceKeys.EXTRA_INFO]["image_affine"] + affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"] + inverse_transform = ApplyTransformToPoints( + dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + # Apply inverse + with inverse_transform.trace_transform(False): + data = inverse_transform(data, affine) + + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 2475060f4e..1279ca93ab 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -35,6 +35,7 @@ from monai.transforms.utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -180,6 +181,9 @@ "ClassesToIndicesd", "ClassesToIndicesD", "ClassesToIndicesDict", + "ApplyTransformToPointsd", + "ApplyTransformToPointsD", + "ApplyTransformToPointsDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1744,6 +1748,75 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ApplyTransformToPointsd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + The output has the same shape as the input. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + refer_key: The key of the reference item used for transformation. + It can directly refer to an affine or an image from which the affine can be derived. + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the refer data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from the image, while the points are in world coordinates. + If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``. + affine_lps_to_ras: Defaults to ``False``. 
Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + allow_missing_keys: Don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + refer_key: str | None = None, + dtype: DtypeLike | torch.dtype = torch.float64, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.refer_key = refer_key + self.converter = ApplyTransformToPoints( + dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]): + d = dict(data) + if self.refer_key is not None: + if self.refer_key in d: + refer_data = d[self.refer_key] + else: + raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.") + else: + refer_data = None + affine = getattr(refer_data, "affine", refer_data) + for key in self.key_iterator(d): + coords = d[key] + d[key] = self.converter(coords, affine) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter.inverse(d[key]) + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1784,3 +1857,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N RandCuCIMD = RandCuCIMDict = RandCuCIMd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd +ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1d1f070568..b1f1bbd0f6 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -27,6 +27,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.utils import to_affine_nd from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -35,6 +36,7 @@ from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, + concatenate, cumsum, isfinite, nonzero, @@ -2555,5 +2557,26 @@ def distance_transform_edt( return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] +def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None): + """ + apply affine transformation to a set of points. + + Args: + data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4). + dtype: output data dtype. 
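The `apply_affine_to_points` helper above is plain homogeneous-coordinate algebra; an equivalent hand-rolled version for 2D points (illustration only, CPU tensors assumed):

```python
import torch

points = torch.tensor([[[2.0, 2.0], [4.0, 6.0]]])  # (C, N, 2)
affine = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]])  # 3x3 for 2D

ones = torch.ones(points.shape[0], points.shape[1], 1)
homogeneous = torch.cat((points, ones), dim=2)     # append w=1 -> (C, N, 3)
transformed = (homogeneous @ affine.T)[:, :, :-1]  # drop w -> (C, N, 2)
# tensor([[[4., 4.], [8., 12.]]])
```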
+ """ + data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64) + affine = to_affine_nd(data_.shape[-1], affine) + + homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2) # type: ignore + transformed_homogeneous = torch.matmul(homogeneous, affine.T) + transformed_coordinates = transformed_homogeneous[:, :, :-1] + out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype) + + return out + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 03fa1ceed1..4e36e3cd47 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -148,6 +148,7 @@ dtype_numpy_to_torch, dtype_torch_to_numpy, get_dtype, + get_dtype_string, get_equivalent_dtype, get_numpy_dtype_from_string, get_torch_dtype_from_string, diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e4f97fc4a6..420e935b33 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -33,6 +33,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "get_dtype_string", "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", @@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype: return type(data) +def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str: + """Get a string representation of the dtype.""" + if isinstance(dtype, torch.dtype): + return str(dtype)[6:] + return str(dtype)[3:] + + def convert_to_tensor( data: Any, dtype: DtypeLike | torch.dtype = None, diff --git a/tests/test_apply_transform_to_points.py b/tests/test_apply_transform_to_points.py new file mode 100644 index 0000000000..0c16603996 --- /dev/null +++ b/tests/test_apply_transform_to_points.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.array import ApplyTransformToPoints +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) +AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]) + +TEST_CASES = [ + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, POINT_2D_IMAGE_RAS], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=AFFINE_2), + MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), + None, + False, + False, + POINT_3D_WORLD, + ], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + transform = ApplyTransformToPoints( + dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + affine = image.affine if image is not None else None + output = transform(points, affine) + self.assertTrue(torch.allclose(output, expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out, points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine) + with self.assertRaises(ValueError): + transform(input, affine) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py new file mode 100644 index 0000000000..4cedfa9d66 --- /dev/null +++ b/tests/test_apply_transform_to_pointsd.py @@ -0,0 +1,133 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.dictionary import ApplyTransformToPointsd +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) + +TEST_CASES = [ + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + False, + POINT_2D_IMAGE, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_2D_WORLD, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + False, + False, + POINT_2D_WORLD, + ], + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + True, + POINT_2D_IMAGE_RAS, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + False, + POINT_3D_IMAGE, + ], + ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_3D_WORLD, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + True, + POINT_3D_IMAGE_RAS, + ], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + data = { + "image": image, + "point": points, + "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]), + } + refer_key = "image" if (image is not None and image != "affine") else image + transform = ApplyTransformToPointsd( + keys="point", + refer_key=refer_key, + dtype=torch.int64, + affine=affine, + invert_affine=invert_affine, + 
affine_lps_to_ras=affine_lps_to_ras, + ) + output = transform(data) + + self.assertTrue(torch.allclose(output["point"], expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out["point"], points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine) + with self.assertRaises(ValueError): + transform({"point": input}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_box_points.py b/tests/test_convert_box_points.py new file mode 100644 index 0000000000..5e3d7ee645 --- /dev/null +++ b/tests/test_convert_box_points.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data.box_utils import convert_box_to_standard_mode +from monai.transforms.spatial.array import ConvertBoxToPoints, ConvertPointsToBoxes +from tests.utils import assert_allclose + +TEST_CASE_POINTS_2D = [ + [ + torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]]), + "xyxy", + torch.tensor([[[10, 20], [30, 20], [30, 40], [10, 40]], [[50, 60], [70, 60], [70, 80], [50, 80]]]), + ], + [torch.tensor([[10, 20, 20, 20]]), "ccwh", torch.tensor([[[0, 10], [20, 10], [20, 30], [0, 30]]])], +] +TEST_CASE_POINTS_3D = [ + [ + torch.tensor([[10, 20, 30, 40, 50, 60], [70, 80, 90, 100, 110, 120]]), + "xyzxyz", + torch.tensor( + [ + [ + [10, 20, 30], + [40, 20, 30], + [40, 50, 30], + [10, 50, 30], + [10, 20, 60], + [40, 20, 60], + [40, 50, 60], + [10, 50, 60], + ], + [ + [70, 80, 90], + [100, 80, 90], + [100, 110, 90], + [70, 110, 90], + [70, 80, 120], + [100, 80, 120], + [100, 110, 120], + [70, 110, 120], + ], + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 10, 10, 10]]), + "cccwhd", + torch.tensor( + [ + [ + [5, 15, 25], + [15, 15, 25], + [15, 25, 25], + [5, 25, 25], + [5, 15, 35], + [15, 15, 35], + [15, 25, 35], + [5, 25, 35], + ] + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 40, 50, 60]]), + "xxyyzz", + torch.tensor( + [ + [ + [10, 30, 50], + [20, 30, 50], + [20, 40, 50], + [10, 40, 50], + [10, 30, 60], + [20, 30, 60], + [20, 40, 60], + [10, 40, 60], + ] + ] + ), + ], +] + +TEST_CASES = TEST_CASE_POINTS_2D + TEST_CASE_POINTS_3D + + +class TestConvertBoxToPoints(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, expected_points): + transform = ConvertBoxToPoints(mode=mode) + converted_points = transform(boxes) + assert_allclose(converted_points, expected_points, type_test=False) + + +class TestConvertPointsToBoxes(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, points): + transform = ConvertPointsToBoxes() + converted_boxes = transform(points) + expected_boxes = convert_box_to_standard_mode(boxes, mode) + 
        assert_allclose(converted_boxes, expected_boxes, type_test=False)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 6a0e1b043ba2890e1463fa49df76f66e56a68b08 Mon Sep 17 00:00:00 2001
From: Yufan He <59374597+heyufan1995@users.noreply.github.com>
Date: Mon, 2 Sep 2024 01:10:10 -0500
Subject: [PATCH 09/15] Fix vista3d transpose bug (#8059)

Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: heyufan1995
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Yiheng Wang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
---
 monai/apps/vista3d/inferer.py  |  2 +-
 monai/networks/nets/vista3d.py | 16 ++++++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py
index 709f81f624..8f622ef6cd 100644
--- a/monai/apps/vista3d/inferer.py
+++ b/monai/apps/vista3d/inferer.py
@@ -100,7 +100,7 @@ def point_based_window_inferer(
                 point_labels=point_labels,
                 class_vector=class_vector,
                 prompt_class=prompt_class,
-                patch_coords=unravel_slice,
+                patch_coords=[unravel_slice],
                 prev_mask=prev_mask,
                 **kwargs,
             )
diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py
index 979a090df0..4215a9a594 100644
--- a/monai/networks/nets/vista3d.py
+++ b/monai/networks/nets/vista3d.py
@@ -336,7 +336,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
     def forward(
         self,
         input_images: torch.Tensor,
-        patch_coords: Sequence[slice] | None = None,
+        patch_coords: list[Sequence[slice]] | None = None,
         point_coords: torch.Tensor | None = None,
         point_labels: torch.Tensor | None = None,
         class_vector: torch.Tensor | None = None,
@@ -364,8 +364,12 @@
             the points are for zero-shot or supported class. When class_vector and point_coords are both
             provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
             will be considered novel class.
-        patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
-            This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
+        patch_coords: a list of sequences of the Python slice objects representing the patch coordinates during sliding window
+            inference. This value is passed from sliding_window_inferer.
+            This is an indicator for training phase or validation phase.
+            Notice that for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will include
+            coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
+            functions using patch_coords will by default use patch_coords[0].
        labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
        label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
            this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
            evaluation, this label_set should be the original index.
@@ -395,14 +399,14 @@ def forward(
             if val_point_sampler is None:
                 # TODO: think about how to refactor this part.
                 val_point_sampler = self.sample_points_patch_val
-            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
+            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
             if prompt_class[0].item() == 0:  # type: ignore
                 point_labels[0] = -1  # type: ignore
             labels, prev_mask = None, None
         elif point_coords is not None:
             # If not performing patch-based point only validation, use user provided click points for inference.
             # the point clicks are in the original image space; convert them to the current patch-coordinate space.
-            point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels)  # type: ignore
+            point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels)  # type: ignore

         if point_coords is not None and point_labels is not None:
             # remove points that were used for padding purposes (point_label = -1)
@@ -455,7 +459,7 @@ def forward(
             logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
         if prev_mask is not None and patch_coords is not None:
             logits = self.connected_components_combine(
-                prev_mask[patch_coords].transpose(1, 0).to(logits.device),
+                prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
                 logits[mapping_index],
                 point_coords,  # type: ignore
                 point_labels,  # type: ignore

From c9b8bdb93b8a5e5fd4e9a72c34dba760f9c66a04 Mon Sep 17 00:00:00 2001
From: BenjaminLi <76795078+25benjaminli@users.noreply.github.com>
Date: Mon, 2 Sep 2024 11:06:34 -0400
Subject: [PATCH 10/15] Add deterministic support for RandSimulateLowResolutiond (#8057)

Fixes #7911, which describes how the RandSimulateLowResolutiond dictionary transform produces non-deterministic outputs, yet the typical array transform RandSimulateLowResolution produces deterministic ones.

### Description

Inside `RandSimulateLowResolutiond`, added the line `self.sim_lowres_tfm.set_random_state(seed, state)` in `set_random_state` to ensure the helper transform `sim_lowres_tfm` is seeded and the transform can be performed deterministically.

Note: I also sifted through the other dictionary transforms with helper functions and did not find anything with a similar problem.

### Types of changes
- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
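For illustration, a minimal sketch of the behavior this change guarantees (the key name, volume shape, and seed below are arbitrary, not from this PR):

```
import torch
from monai.transforms import RandSimulateLowResolutiond

data = {"image": torch.rand(1, 16, 16, 16)}  # arbitrary channel-first volume

# With this fix, seeding the dictionary transform also seeds its internal
# RandSimulateLowResolution helper, so two identically seeded runs agree.
t1 = RandSimulateLowResolutiond(keys="image", prob=1.0)
t1.set_random_state(seed=42)
t2 = RandSimulateLowResolutiond(keys="image", prob=1.0)
t2.set_random_state(seed=42)

print(torch.allclose(t1(data)["image"], t2(data)["image"]))  # expected: True after this fix
```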
Signed-off-by: BenjaminLi <25benjaminli@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 82dee15c7c..2b80034a07 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2588,6 +2588,7 @@ def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None ) -> RandSimulateLowResolutiond: super().set_random_state(seed, state) + self.sim_lowres_tfm.set_random_state(seed, state) return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: From dbfe418c03073baf07a0e14cd7606571f3d0de18 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 3 Sep 2024 00:35:26 +0800 Subject: [PATCH 11/15] Ignore warning from nptyping as workaround (#8062) workaround for #8061 ### Description Ignore warning from nptyping as workaround ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/image_reader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f5e199e2a3..aab1e03898 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -34,6 +34,9 @@ ) from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg +# workaround for https://github.com/Project-MONAI/MONAI/issues/8061 +warnings.filterwarnings("ignore", category=DeprecationWarning, module="nptyping") + if TYPE_CHECKING: import itk import nibabel as nib From befb5f6af1521e53ee3e757ea66df30b40d24801 Mon Sep 17 00:00:00 2001 From: Bastian Wittmann <73648286+bwittmann@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:55:00 +0200 Subject: [PATCH 12/15] Forced `Fourier` class to output contiguous tensors. (#7969) Forced `Fourier` class to output contiguous tensors, which potentially fixes a performance bottleneck. ### Description Some transforms, such as `RandKSpaceSpikeNoise`, rely on the `Fourier` class. In its current state, the `Fourier` class returns non-contiguous tensors, which potentially limits performance. For example, when followed by `RandHistogramShift`, the following warning occurs: ``` /monai/transforms/intensity/array.py:1852: UserWarning: torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible. This message will only appear once per program. (Triggered internally at /opt/conda/conda-bld/pytorch_1716905975447/work/aten/src/ATen/native/BucketizationUtils.h:32.) 
indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
```

A straightforward fix is to force the `Fourier` class to output contiguous tensors (see commit).

To reproduce, please run:
```
from monai.transforms import RandKSpaceSpikeNoise
from monai.transforms.utils import Fourier
import numpy as np
import torch

### TEST WITH TRANSFORMS ###
t = RandKSpaceSpikeNoise(prob=1)

# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
a_torch_mod = t(a_torch)
print(a_torch_mod.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
a_np_mod = t(a_np)  # automatically transformed to torch.tensor
print(a_np_mod.is_contiguous())

### TEST DIRECTLY WITH FOURIER ###
f = Fourier()

# inv_shift_fourier
# for torch tensors
real_torch = torch.randn(1, 128, 128, 128)
im_torch = torch.randn(1, 128, 128, 128)
k_torch = torch.complex(real_torch, im_torch)
print(k_torch.is_contiguous())
out_torch = f.inv_shift_fourier(k_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
real_np = np.random.randn(1, 100, 100, 100)
im_np = np.random.randn(1, 100, 100, 100)
k_np = real_np + 1j * im_np
print(k_np.flags['C_CONTIGUOUS'])
out_np = f.inv_shift_fourier(k_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])

# shift_fourier
# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
out_torch = f.shift_fourier(a_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
out_np = f.shift_fourier(a_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])
```

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Bastian Wittmann .
Signed-off-by: Bastian Wittmann
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/transforms/utils.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index b1f1bbd0f6..32fffc25f0 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -1863,7 +1863,7 @@ class Fourier:
     """

     @staticmethod
-    def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
+    def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
         """
         Applies fourier transform and shifts the zero-frequency component to the
         center of the spectrum. Only the spatial dimensions get transformed.
@@ -1871,6 +1871,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
         Args:
             x: Image to transform.
             spatial_dims: Number of spatial dimensions.
+            as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.

         Returns
             k: K-space data.
@@ -1885,10 +1886,12 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor: k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims) else: k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) - return k + return ascontiguousarray(k) if as_contiguous else k @staticmethod - def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor: + def inv_shift_fourier( + k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False + ) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. @@ -1896,6 +1899,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None Args: k: K-space data. spatial_dims: Number of spatial dimensions. + as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous. Returns: x: Tensor in image space. @@ -1910,7 +1914,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real else: out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real - return out + return ascontiguousarray(out) if as_contiguous else out def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int: From aea46ff26b39c0c88e3d00cb88cb03442df61dd5 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 4 Sep 2024 01:31:28 -0700 Subject: [PATCH 13/15] Trt compiler fixes (#8064) Fixes https://github.com/Project-MONAI/MONAI/issues/8061. ### Description Post-merge fixes for trt_compile() ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
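As context for the fixes above, a hedged sketch of a typical `trt_compile` call; the network choice, plan path, and argument values are illustrative assumptions, and it requires a CUDA build with TensorRT and polygraphy installed:

```
import torch
from monai.networks import trt_compile
from monai.networks.nets import UNet

model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2)).cuda()

# Patches model.forward() with a TRT hook; the engine plan is cached next to
# the base path ("/tmp/unet" is an illustrative path) and rebuilt only if missing.
model = trt_compile(model, "/tmp/unet", args={"precision": "fp16"})

with torch.no_grad():
    out = model(torch.rand(1, 1, 96, 96, 96, device="cuda"))
```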
--------- Signed-off-by: Boris Fomitchev Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> --- monai/networks/trt_compiler.py | 8 ++++++-- tests/test_trt_compile.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a9dd0d9e9b..00d2eb61af 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs): self._build_and_save(model, build_args) # This will reassign input_names from the engine self._load_engine() + assert self.engine is not None except Exception as e: if self.fallback: self.logger.info(f"Failed to build engine: {e}") @@ -403,8 +404,10 @@ def _onnx_to_trt(self, onnx_path): build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" - build_args["fp16"] = self.precision == "fp16" - build_args["bf16"] = self.precision == "bf16" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) @@ -502,6 +505,7 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 21125d203f..2f9db8f0c2 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -20,10 +20,10 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 -from monai.utils import optional_import +from monai.utils import min_version, optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt, trt_imported = optional_import("tensorrt") +trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") From 4e70bf694c5178637f4749a84e7a59a8d07332e7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:52:52 +0800 Subject: [PATCH 14/15] Allow ApplyTransformToPointsd receive a sequence of refer keys (#8063) Enhance `ApplyTransformToPointsd` to receive a sequence of refer keys. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. 
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 monai/transforms/utility/array.py        | 64 ++++++-----
 monai/transforms/utility/dictionary.py   | 28 ++---
 tests/test_apply_transform_to_pointsd.py | 136 ++++++++++++++++-------
 3 files changed, 146 insertions(+), 82 deletions(-)

diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index fee546bea3..72dd189009 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -1764,6 +1764,30 @@ def __init__(
         self.invert_affine = invert_affine
         self.affine_lps_to_ras = affine_lps_to_ras

+    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
+        """
+        Compute the final affine transformation matrix to apply to the point data.
+
+        Args:
+            affine: 3x3 or 4x4 affine transformation matrix.
+            applied_affine: affine transformation matrix that has already been applied to the point data, if any.
+
+        Returns:
+            Final affine transformation matrix.
+        """
+
+        affine = convert_data_type(affine, dtype=torch.float64)[0]
+
+        if self.affine_lps_to_ras:
+            affine = orientation_ras_lps(affine)
+
+        if self.invert_affine:
+            affine = linalg_inv(affine)
+        if applied_affine is not None:
+            affine = affine @ applied_affine
+
+        return affine
+
     def transform_coordinates(
         self, data: torch.Tensor, affine: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, dict]:
@@ -1780,35 +1804,25 @@
             Transformed coordinates.
         """
         data = convert_to_tensor(data, track_meta=get_track_meta())
-        # applied_affine is the affine transformation matrix that has already been applied to the point data
-        applied_affine = getattr(data, "affine", None)
-
         if affine is None and self.invert_affine:
             raise ValueError("affine must be provided when invert_affine is True.")
-
+        # applied_affine is the affine transformation matrix that has already been applied to the point data
+        applied_affine: torch.Tensor | None = getattr(data, "affine", None)
         affine = applied_affine if affine is None else affine
-        affine = convert_data_type(affine, dtype=torch.float64)[0]  # always convert to float64 for affine
-        original_affine: torch.Tensor = affine
-        if self.affine_lps_to_ras:
-            affine = orientation_ras_lps(affine)
+        if affine is None:
+            raise ValueError("affine must be provided if data does not have an affine matrix.")

-        # the final affine transformation matrix that will be applied to the point data
-        _affine: torch.Tensor = affine
-        if self.invert_affine:
-            _affine = linalg_inv(affine)
-        if applied_affine is not None:
-            # consider the affine transformation already applied to the data in the world space
-            # and compute delta affine
-            _affine = _affine @ linalg_inv(applied_affine)
-        out = apply_affine_to_points(data, _affine, dtype=self.dtype)
+        final_affine = self._compute_final_affine(affine, applied_affine)
+        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)

         extra_info = {
             "invert_affine": self.invert_affine,
             "dtype": get_dtype_string(self.dtype),
-            "image_affine": original_affine,  # record for inverse operation
+            "image_affine": affine,
             "affine_lps_to_ras": self.affine_lps_to_ras,
         }
-        xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)
+
+        xform = orientation_ras_lps(linalg_inv(final_affine)) if
self.affine_lps_to_ras else linalg_inv(final_affine) meta_info = TraceableTransform.track_transform_meta( data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info() ) @@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None): def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) - # Create inverse transform - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"] - affine = transform[TraceKeys.EXTRA_INFO]["image_affine"] - affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"] inverse_transform = ApplyTransformToPoints( - dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + dtype=transform[TraceKeys.EXTRA_INFO]["dtype"], + invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"], + affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"], ) - # Apply inverse with inverse_transform.trace_transform(False): - data = inverse_transform(data, affine) + data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"]) return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1279ca93ab..db5f19c0de 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform): Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform - refer_key: The key of the reference item used for transformation. - It can directly refer to an affine or an image from which the affine can be derived. + refer_keys: The key of the reference item used for transformation. + It can directly refer to an affine or an image from which the affine can be derived. It can also be a + sequence of keys, in which case each refers to the affine applied to the matching points in `keys`. dtype: The desired data type for the output. affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates from the image. 
For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary @@ -1782,7 +1783,7 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - refer_key: str | None = None, + refer_keys: KeysCollection | None = None, dtype: DtypeLike | torch.dtype = torch.float64, affine: torch.Tensor | None = None, invert_affine: bool = True, @@ -1790,23 +1791,24 @@ def __init__( allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) - self.refer_key = refer_key + self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys)) self.converter = ApplyTransformToPoints( dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras ) def __call__(self, data: Mapping[Hashable, torch.Tensor]): d = dict(data) - if self.refer_key is not None: - if self.refer_key in d: - refer_data = d[self.refer_key] - else: - raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.") - else: - refer_data = None - affine = getattr(refer_data, "affine", refer_data) - for key in self.key_iterator(d): + for key, refer_key in self.key_iterator(d, self.refer_keys): coords = d[key] + affine = None # represents using affine given in constructor + if refer_key is not None: + if refer_key in d: + refer_data = d[refer_key] + else: + raise KeyError(f"The refer_key '{refer_key}' is not found in the data.") + + # use the "affine" member of refer_data, or refer_data itself, as the affine matrix + affine = getattr(refer_data, "affine", refer_data) d[key] = self.converter(coords, affine) return d diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py index 4cedfa9d66..978113931c 100644 --- a/tests/test_apply_transform_to_pointsd.py +++ b/tests/test_apply_transform_to_pointsd.py @@ -30,72 +30,90 @@ POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) +AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]) TEST_CASES = [ + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine + [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine [ - MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, - False, - POINT_2D_IMAGE, - ], + True, + POINT_2D_IMAGE_RAS, + ], # test affine_lps_to_ras + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself [ - None, - MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(DATA_3D, affine=AFFINE_2), + MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), None, False, False, - POINT_2D_WORLD, + POINT_3D_WORLD, ], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, 
POINT_3D_IMAGE_RAS], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], +] +TEST_CASES_SEQUENCE = [ [ + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [POINT_2D_WORLD, POINT_3D_WORLD], None, - MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), - torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), - False, + True, False, - POINT_2D_WORLD, - ], + ["image_1", "image_2"], + [POINT_2D_IMAGE, POINT_3D_IMAGE], + ], # use image affine [ - MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), - POINT_2D_WORLD, + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [POINT_2D_WORLD, POINT_3D_WORLD], None, True, True, - POINT_2D_IMAGE_RAS, - ], + ["image_1", "image_2"], + [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS], + ], # test affine_lps_to_ras [ - MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), - POINT_3D_WORLD, + (None, None), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], None, + False, + False, + None, + [POINT_2D_WORLD, POINT_3D_WORLD], + ], # use point affine + [ + (None, None), + [POINT_2D_WORLD, POINT_2D_WORLD], + AFFINE_1, True, False, - POINT_3D_IMAGE, - ], - ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + None, + [POINT_2D_IMAGE, POINT_2D_IMAGE], + ], # use input affine [ - MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), - MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], None, False, False, - POINT_3D_WORLD, - ], - [ - MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), - POINT_3D_WORLD, - None, - True, - True, - POINT_3D_IMAGE_RAS, + ["image_1", "image_2"], + [POINT_2D_WORLD, POINT_3D_WORLD], ], ] TEST_CASES_WRONG = [ - [POINT_2D_WORLD, True, None], - [POINT_2D_WORLD.unsqueeze(0), False, None], - [POINT_3D_WORLD[..., 0:1], False, None], - [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], + [POINT_2D_WORLD, True, None, None], + [POINT_2D_WORLD.unsqueeze(0), False, None, None], + [POINT_3D_WORLD[..., 0:1], False, None, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None], + [POINT_3D_WORLD, False, None, "image"], + [POINT_3D_WORLD, False, None, []], ] @@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin "point": points, "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]), } - refer_key = "image" if (image is not None and image != "affine") else image + refer_keys = "image" if (image is not None and image != "affine") else image transform = ApplyTransformToPointsd( keys="point", - refer_key=refer_key, + refer_keys=refer_keys, dtype=torch.int64, affine=affine, invert_affine=invert_affine, @@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin invert_out = transform.inverse(output) self.assertTrue(torch.allclose(invert_out["point"], points)) + @parameterized.expand(TEST_CASES_SEQUENCE) + 
def test_transform_coordinates_sequences(
+        self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
+    ):
+        data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
+        keys = ["point_1", "point_2"]
+        transform = ApplyTransformToPointsd(
+            keys=keys,
+            refer_keys=refer_keys,
+            dtype=torch.int64,
+            affine=affine,
+            invert_affine=invert_affine,
+            affine_lps_to_ras=affine_lps_to_ras,
+        )
+        output = transform(data)
+
+        self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
+        self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
+        invert_out = transform.inverse(output)
+        self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))
+
     @parameterized.expand(TEST_CASES_WRONG)
-    def test_wrong_input(self, input, invert_affine, affine):
-        transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
-        with self.assertRaises(ValueError):
-            transform({"point": input})
+    def test_wrong_input(self, input, invert_affine, affine, refer_keys):
+        if refer_keys == []:
+            with self.assertRaises(ValueError):
+                ApplyTransformToPointsd(
+                    keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+                )
+        else:
+            transform = ApplyTransformToPointsd(
+                keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+            )
+            data = {"point": input}
+            if refer_keys == "image":
+                with self.assertRaises(KeyError):
+                    transform(data)
+            else:
+                with self.assertRaises(ValueError):
+                    transform(data)


 if __name__ == "__main__":
     unittest.main()

From 19cc6f01766120132f964beecb06d1d561f83801 Mon Sep 17 00:00:00 2001
From: "Wei_Chuan, Chiang" <45346252+slicepaste@users.noreply.github.com>
Date: Wed, 4 Sep 2024 18:42:49 +0800
Subject: [PATCH 15/15] Make MetaTensor optional printed in DataStats and DataStatsd #5905 (#7814)

Fixes #5905

### Description
We simply add one argument for DataStats and DataStatsd to make printing the MetaTensor metadata optional.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
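A small usage sketch of the new flag (the tensor contents, meta dictionary, and key name are illustrative):

```
import torch
from monai.data import MetaTensor
from monai.transforms import DataStats, DataStatsd

img = MetaTensor(torch.rand(1, 8, 8), meta={"some": "info"})

# meta_info=True additionally logs the MetaTensor's meta dictionary;
# for a plain tensor the logged line reads "(input is not a MetaTensor)".
stats = DataStats(prefix="img", meta_info=True)
_ = stats(img)

# the dictionary version accepts the same flag (optionally one value per key)
statsd = DataStatsd(keys="img", meta_info=True)
_ = statsd({"img": img})
```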
---------

Signed-off-by: Wei_Chuan, Chiang
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Suraj Pai
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: Suraj Pai
Co-authored-by: Ben Murray
---
 monai/transforms/utility/array.py      |  7 ++++
 monai/transforms/utility/dictionary.py | 28 ++++++++++++--
 tests/test_data_stats.py               | 41 ++++++++++++++++++-
 tests/test_data_statsd.py              | 54 +++++++++++++++++++++++++-
 4 files changed, 123 insertions(+), 7 deletions(-)

diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index bfd2f506c2..72dd189009 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -656,6 +656,7 @@ def __init__(
         data_shape: bool = True,
         value_range: bool = True,
         data_value: bool = False,
+        meta_info: bool = False,
         additional_info: Callable | None = None,
         name: str = "DataStats",
     ) -> None:
@@ -667,6 +668,7 @@ def __init__(
             value_range: whether to show the value range of input data.
             data_value: whether to show the raw value of input data.
                 a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+            meta_info: whether to show the meta information of the input data (printed only when the input is a MetaTensor).
             additional_info: user can define callable function to extract
                 additional info from input data.
             name: identifier of `logging.logger` to use, defaulting to "DataStats".
@@ -681,6 +683,7 @@ def __init__(
         self.data_shape = data_shape
         self.value_range = value_range
         self.data_value = data_value
+        self.meta_info = meta_info
         if additional_info is not None and not callable(additional_info):
             raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
         self.additional_info = additional_info
@@ -707,6 +710,7 @@ def __call__(
         data_shape: bool | None = None,
         value_range: bool | None = None,
         data_value: bool | None = None,
+        meta_info: bool | None = None,
         additional_info: Callable | None = None,
     ) -> NdarrayOrTensor:
         """
@@ -727,6 +731,9 @@
             lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
         if self.data_value if data_value is None else data_value:
             lines.append(f"Value: {img}")
+        if self.meta_info if meta_info is None else meta_info:
+            metadata = getattr(img, "meta", "(input is not a MetaTensor)")
+            lines.append(f"Meta info: {repr(metadata)}")
         additional_info = self.additional_info if additional_info is None else additional_info
         if additional_info is not None:
             lines.append(f"Additional info: {additional_info(img)}")
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index db5f19c0de..79d0be522d 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -793,6 +793,7 @@ def __init__(
         data_shape: Sequence[bool] | bool = True,
         value_range: Sequence[bool] | bool = True,
         data_value: Sequence[bool] | bool = False,
+        meta_info: Sequence[bool] | bool = False,
         additional_info: Sequence[Callable] | Callable | None = None,
         name: str = "DataStats",
         allow_missing_keys: bool = False,
@@ -812,6 +813,8 @@ def __init__(
             data_value: whether to show the raw value of input data.
                 it also can be a sequence of bool, each element corresponds to a key in ``keys``.
                 a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+            meta_info: whether to show the meta information of the input data (printed only when the input is a MetaTensor).
+ it also can be a sequence of bool, each element corresponds to a key in ``keys``. additional_info: user can define callable function to extract additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. @@ -825,15 +828,34 @@ def __init__( self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) + self.meta_info = ensure_tuple_rep(meta_info, len(self.keys)) self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) self.printer = DataStats(name=name) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( - d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info + for ( + key, + prefix, + data_type, + data_shape, + value_range, + data_value, + meta_info, + additional_info, + ) in self.key_iterator( + d, + self.prefix, + self.data_type, + self.data_shape, + self.value_range, + self.data_value, + self.meta_info, + self.additional_info, ): - d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info) + d[key] = self.printer( + d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info + ) return d diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 05453b0694..f9b424f8e1 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -23,6 +23,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStats TEST_CASE_1 = [ @@ -130,20 +131,55 @@ ] TEST_CASE_8 = [ + { + "prefix": "test data", + "data_type": True, + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": np.mean, + "name": "DataStats", + }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_9 = [ + np.array([[0, 1], [1, 2]]), + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_10 = [ + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ), + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + "Additional info: 1.0\n", +] + class TestDataStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + ) def test_value(self, input_param, input_data, expected_print): transform = DataStats(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_8]) + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) def test_file(self, input_data, expected_print): 
with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_data_stats.log") @@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, } diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index ef88300c10..a28a938c40 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -21,6 +21,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStatsd TEST_CASE_1 = [ @@ -150,22 +151,70 @@ ] TEST_CASE_9 = [ + { + "keys": "img", + "prefix": "test data", + "data_shape": True, + "value_range": True, + "data_value": True, + "meta_info": False, + "additional_info": np.mean, + "name": "DataStats", + }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_10 = [ + {"img": np.array([[0, 1], [1, 2]])}, + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_11 = [ + { + "img": ( + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ) + ) + }, + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + "Additional info: 1.0\n", +] + class TestDataStatsd(unittest.TestCase): @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + ] ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_9]) + @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_stats.log") @@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, }