diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index 4a1f383..f619c19 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -212,6 +212,7 @@ def get_corpus_coms( # Check against ivd order seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) + has_ivd: bool = Location.Vertebra_Disc.value in seg_sem.unique() subreg_cc, subreg_cc_n = seg_sem.get_segmentation_connected_components(labels=Location.Vertebra_Disc.value) subreg_cc = subreg_cc[Location.Vertebra_Disc.value] subreg_cc[subreg_cc > 0] += 100 @@ -242,7 +243,7 @@ def get_corpus_coms( continue nkey = stats_by_height_keys[nidx] - if nkey in stats_by_height and stats_by_height[nkey][1] == is_ivd: + if nkey in stats_by_height and stats_by_height[nkey][1] == is_ivd and has_ivd: neighbor = stats_by_height[nkey] neighborheight = neighbor[0] logger.print( diff --git a/spineps/phase_pre.py b/spineps/phase_pre.py index f92ac61..39214ac 100644 --- a/spineps/phase_pre.py +++ b/spineps/phase_pre.py @@ -14,6 +14,7 @@ def preprocess_input( mri_nii: NII, debug_data: dict, # noqa: ARG001 pad_size: int = 4, + proc_normalize_input: bool = True, proc_do_n4_bias_correction: bool = True, proc_crop_input: bool = True, verbose: bool = False, @@ -24,8 +25,15 @@ def preprocess_input( # Crop Down try: # Enforce to range [0, 1500] - mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=logger) - crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None)) + if proc_normalize_input: + mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=logger) + crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None)) + else: + crop = ( + mri_nii.compute_crop(minimum=mri_nii.min(), dist=0) + if proc_crop_input + else (slice(None, None), slice(None, None), slice(None, None)) + ) except ValueError: logger.print("Image Nifty is empty, skip this", Log_Type.FAIL) return None, ErrCode.EMPTY @@ -40,7 +48,8 @@ def preprocess_input( logger.print(f"N4 Bias field correction done in {perf_counter() - n4_start} sec", verbose=True) # Enforce to range [0, 1500] - cropped_nii.normalize_to_range_(min_value=0, max_value=1500, verbose=logger) + if proc_normalize_input: + cropped_nii.normalize_to_range_(min_value=0, max_value=1500, verbose=logger) # Uncrop again # uncropped_input[crop] = cropped_nii.get_array() diff --git a/spineps/seg_enums.py b/spineps/seg_enums.py index 24d89e1..a4ec94a 100755 --- a/spineps/seg_enums.py +++ b/spineps/seg_enums.py @@ -50,6 +50,7 @@ class Modality(Enum_Compare): CT = auto() SEG = auto() MPR = auto() + PD = auto() @classmethod def format_keys(cls, modalities: Self | list[Self]) -> list[str]: diff --git a/spineps/seg_model.py b/spineps/seg_model.py index 37d8666..07097e0 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -355,8 +355,8 @@ def run( targetc = targetc.to(torch.float32) logits = self.predictor.forward(targetc.to(self.device)) # - except Exception: - # print("Channel-wise model failed, try legacy version") + except Exception as e: + print(f"Channel-wise model failed with {e}, try legacy version") do_backup = True # if do_backup: diff --git a/spineps/seg_run.py b/spineps/seg_run.py index d842c8d..07cb536 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -250,6 +250,7 @@ def process_img_nii( # noqa: C901 override_postpair: bool = False, override_ctd: bool = False, proc_pad_size: int = 4, + proc_normalize_input: bool = True, # Processings # Semantic proc_sem_crop_input: bool = True, @@ -352,6 +353,9 @@ def process_img_nii( # noqa: C901 done_something = False debug_data_run: dict[str, NII] = {} + if Modality.CT in model_semantic.modalities(): + proc_normalize_input = False # Never normalize input if it is an CT + compatible = check_input_model_compatibility(img_ref, model=model_semantic) if not compatible: if not ignore_compatibility_issues: @@ -378,6 +382,7 @@ def process_img_nii( # noqa: C901 pad_size=proc_pad_size, debug_data=debug_data_run, proc_crop_input=proc_sem_crop_input, + proc_normalize_input=proc_normalize_input, proc_do_n4_bias_correction=proc_sem_n4_bias_correction, verbose=verbose, )