Skip to content

Commit

Permalink
made ruffclean
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrik-code committed Feb 29, 2024
1 parent 6a6b100 commit 6ae0a50
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions spineps/utils/inference_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import os.path
import torch
from spineps.utils.predictor import nnUNetPredictor
from TPTBox import NII, No_Logger, Log_Type
from TPTBox.core import sitk_utils
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
from TPTBox import NII, Log_Type, No_Logger
from TPTBox.core import sitk_utils

from spineps.utils.predictor import nnUNetPredictor

logger = No_Logger()
logger.override_prefix = "API"
Expand All @@ -15,7 +17,7 @@
# Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring
# method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
def load_inf_model(
model_folder: str,
model_folder: str | Path,
step_size: float = 0.5,
ddevice: str = "cuda",
use_folds: tuple[str, ...] | None = None,
Expand All @@ -36,6 +38,8 @@ def load_inf_model(
Returns:
predictor: Loaded model predictor object
"""
if isinstance(model_folder, str):
model_folder = Path(model_folder)
if ddevice == "cpu":
# let's allow torch to use hella threads
import multiprocessing
Expand All @@ -50,7 +54,7 @@ def load_inf_model(
else:
device = torch.device("mps")

assert os.path.exists(model_folder), f"model-folder not found: got path {model_folder}"
assert model_folder.exists(), f"model-folder not found: got path {model_folder}"

predictor = nnUNetPredictor(
tile_step_size=step_size,
Expand All @@ -63,21 +67,21 @@ def load_inf_model(
)
check_name = "checkpoint_final.pth" # if not allow_non_final else "checkpoint_best.pth"
try:
predictor.initialize_from_trained_model_folder(model_folder, checkpoint_name=check_name, use_folds=use_folds)
predictor.initialize_from_trained_model_folder(str(model_folder), checkpoint_name=check_name, use_folds=use_folds)
except Exception as e:
if allow_non_final:
predictor.initialize_from_trained_model_folder(model_folder, checkpoint_name="checkpoint_best.pth", use_folds=use_folds)
predictor.initialize_from_trained_model_folder(str(model_folder), checkpoint_name="checkpoint_best.pth", use_folds=use_folds)
logger.print("Checkpoint final not found, will load from best instead", Log_Type.WARNING)
else:
raise e
raise e # noqa: TRY201
logger.print(f"Inference Model loaded from {model_folder}") if verbose else None
return predictor


def run_inference(
input: str | NII | list[NII],
input_nii: str | NII | list[NII],
predictor: nnUNetPredictor,
reorient_PIR: bool = False,
reorient_PIR: bool = False, # noqa: N803
) -> tuple[NII, NII | None, np.ndarray]:
"""Runs nnUnet model inference on one input.
Expand All @@ -91,29 +95,29 @@ def run_inference(
Returns:
Segmentation (NII), Uncertainty Map (NII), Softmax Logits (numpy arr)
"""
if isinstance(input, str):
assert input.endswith(".nii.gz"), f"input file is not a .nii.gz! Got {input}"
input = NII.load(input, seg=False)
if isinstance(input_nii, str):
assert input_nii.endswith(".nii.gz"), f"input file is not a .nii.gz! Got {input_nii}"
input_nii = NII.load(input_nii, seg=False)

assert isinstance(input, NII) or isinstance(input, list), f"input must be a NII or str or list[NII], got {type(input)}"
if isinstance(input, NII):
input = [input]
orientation = input[0].orientation
header = input[0].header
assert isinstance(input_nii, NII | list), f"input must be a NII or str or list[NII], got {type(input_nii)}"
if isinstance(input_nii, NII):
input_nii = [input_nii]
orientation = input_nii[0].orientation
header = input_nii[0].header

img_arrs = []
# Prepare for nnUNet behavior
for i in input:
for i in input_nii:
if reorient_PIR:
i.reorient_()
sitk_nii = sitk_utils.nii_to_sitk(i)
nii_img_converted = i.get_array()
# nii_img_converted = np.pad(nii_img_converted, pad_width=pad_size, mode="edge")
nii_img_converted = np.swapaxes(nii_img_converted, 0, 2)[np.newaxis, :].astype(np.float16)
img_arrs.append(nii_img_converted)
affine = input[0].affine
affine = input_nii[0].affine
img = np.vstack(img_arrs)
zoom = input[0].zoom
zoom = input_nii[0].zoom
props = {
"sitk_stuff": {
# this saves the sitk geometry information. This part is NOT used by nnU-Net!
Expand Down

0 comments on commit 6ae0a50

Please sign in to comment.