Skip to content

Commit

Permalink
1428 update to the latest usage SaveImage (#1429)
Browse files Browse the repository at this point in the history
Fixes #1428



### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Avoid including large-size files in the PR.
- [x] Clean up long text outputs from code cells in the notebook.
- [x] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [x] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [x] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jun 18, 2023
1 parent 5de3e5f commit bbc4e18
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions 3d_segmentation/challenge_baseline/run_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import shutil
import sys

import numpy as np
import torch
import torch.nn as nn
from ignite.contrib.handlers import ProgressBar
Expand All @@ -42,7 +41,7 @@ def get_xforms(mode="train", keys=("image", "label")):
"""returns a composed transform for train/val/infer."""

xforms = [
LoadImaged(keys, ensure_channel_first=True),
LoadImaged(keys, ensure_channel_first=True, image_only=True),
Orientationd(keys, axcodes="LPS"),
Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
Expand Down Expand Up @@ -239,7 +238,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
)

inferer = get_inferer()
saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
saver = monai.transforms.SaveImage(output_dir=prediction_folder, mode="nearest", resample=True)
with torch.no_grad():
for infer_data in infer_loader:
logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}")
Expand All @@ -258,7 +257,8 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
n = n + 1.0
preds = preds / n
preds = (preds.argmax(dim=1, keepdims=True)).float()
saver.save_batch(preds, infer_data["image"].meta)
for p in preds: # save each image+metadata in the batch respectively
saver(p)

# copy the saved segmentations into the required folder structure for submission
submission_dir = os.path.join(prediction_folder, "to_submit")
Expand Down

0 comments on commit bbc4e18

Please sign in to comment.