diff --git a/2d_lesion_detection/PR_curve.py b/2d_lesion_detection/PR_curve.py new file mode 100644 index 0000000..e941f2c --- /dev/null +++ b/2d_lesion_detection/PR_curve.py @@ -0,0 +1,111 @@ +""" +Script for generating Precision-Recall curve and PR-AUC + +First, run yolo inference with a low confidence threshold (LOWER_CONF), +then give those predictions as --preds +""" + +import os +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from pathlib import Path +import subprocess +import tempfile +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +LOWER_CONF = 0.01 +UPPER_CONF = 0.5 + +def _main(): + parser = ArgumentParser( + prog = 'PR_curve', + description = 'Generate PR curve and AUC-PR for yolo model', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-g', '--gt-path', + required= True, + type = str, + help = 'Path to YOLO dataset folder of ground truth txt files') + parser.add_argument('-p', '--preds', + required = True, + type = Path, + help = 'Path to prediction folder containing txt files with confidence values.') + parser.add_argument('-c', '--canproco', + required= True, + type = str, + help = 'Path to canproco database') + parser.add_argument('-o', '--output', + required = True, + type = Path, + help = 'Output directory to save the PR curve to.') + parser.add_argument('-i', '--iou', + default= 0.2, + type = str, + help = 'IoU threshold for a TP') + + args = parser.parse_args() + + # Create output folder if it doesn't exist + os.makedirs(args.output, exist_ok=True) + + recalls = [] + precisions = [] + for conf in np.arange(LOWER_CONF, UPPER_CONF, 0.01): + print(f"\n\nComputing metrics for {conf} conf") + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir)/"preds").mkdir(parents=True, exist_ok=True) + (Path(tmpdir)/"val").mkdir(parents=True, exist_ok=True) + + # 1. Create new txt files with only boxes that have confidence higher than conf + # load predictions + txt_names = os.listdir(args.preds) + txt_paths = [os.path.join(args.preds, file) for file in txt_names if file.endswith(".txt")] # only keep txts + + print("Copying over txt files") + for txt_path in txt_paths: + # For every file create copy but only keeping boxes with confidence higher than conf + with open(txt_path, "r") as infile: + # Read lines from the input file + lines = infile.readlines() + + filtered_lines = [line for line in lines if float(line.split()[-1]) > conf] + + if filtered_lines: + # only create file if there are boxes + filename = Path(txt_path).name + with open(Path(tmpdir)/"preds"/filename, "w") as outfile: + outfile.writelines(filtered_lines) + + # 2. Call validation and get recall and precision + print("Calling validation") + command = ["python", + "validation.py", + "-g", args.gt_path, + "-p", str(Path(tmpdir)/"preds"), + "-o", str(Path(tmpdir)/"val"), + "-c", args.canproco, + "-i", args.iou] + subprocess.run(command, check=True) + + # 3. 
Get recall and precision and add to dict + print("Getting recall and precision") + df = pd.read_csv(Path(tmpdir)/"val"/"metrics_report.csv") + + # Extract Recall and Precision from the last row + recalls.append(df.iloc[-1]['Recall']) + precisions.append(df.iloc[-1]['Precision']) + + # Plot + plt.plot(recalls, precisions, marker='.') + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title(f'Precision-Recall Curve with {args.iou} iou threshold') + plt.savefig(args.output/f'precision_recall_curve_{args.iou}iou.png') + + # Calculate PR-AUC + auc_pr = np.trapz(precisions[::-1], recalls[::-1]) + print('Area under Precision-Recall curve (AUC-PR):', auc_pr) + + +if __name__ == "__main__": + _main() \ No newline at end of file diff --git a/2d_lesion_detection/complete_pre_process.py b/2d_lesion_detection/complete_pre_process.py new file mode 100644 index 0000000..b31d6fe --- /dev/null +++ b/2d_lesion_detection/complete_pre_process.py @@ -0,0 +1,100 @@ +""" +Main script for pre-processing +Calls sc_seg_from_list.py, make_yolo_dataset.py and modify_unlabeled_proportion.py + +Generates a YOLO dataset from a list of scans +""" +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from pathlib import Path +import subprocess +import tempfile + + +def call_sc_seg_from_list(json_path:str|Path, database:str|Path): + """ + Calls sc_seg_from_list.py + """ + print("Getting spinal cord segmentation...") + command = [ + "python", + "sc_seg_from_list.py", + "-j", str(json_path), + "-d", str(database) + ] + subprocess.run(command, check=True) + +def call_make_yolo_dataset(json_path:str|Path, + database:str|Path, + output_dir:str|Path): + """ + Calls make_yolo_dataset.py + """ + print("Converting to YOLO format...") + command = [ + "python", + "make_yolo_dataset.py", + "-j", str(json_path), + "-d", str(database), + "-o", str(output_dir) + ] + subprocess.run(command, check=True) + +def call_modify_unlabeled_proportion(input_path:str|Path, + output_path:str|Path, + ratio: str|float): + """ + Calls modify_unlabeled_proportion.py + """ + print("Modifying unlabeled proportion...") + command = [ + "python", + "modify_unlabeled_proportion.py", + "-i", str(input_path), + "-o", str(output_path), + "-r", str(ratio) + ] + subprocess.run(command, check=True) + +def _main(): + parser = ArgumentParser( + prog = 'complete_pre_process', + description = 'Generates YOLO format dataset from a list of scans and a BIDS database.', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-j', '--json-list', + required = True, + type = Path, + help = 'path to json list of scans to process') + parser.add_argument('-d', '--database', + required = True, + type = Path, + help = 'path to BIDS database (canproco)') + parser.add_argument('-o', '--output-dir', + required = True, + type = Path, + help = 'Output directory for YOLO dataset') + parser.add_argument('-r', '--ratio', + default = None, + type = float, + help = 'Proportion of dataset that should be unlabeled. 
' + 'By default, the ratio is not modified and the whole dataset is kept.') + + args = parser.parse_args() + + # Make sure all necessary spinal cord segmentations are present + call_sc_seg_from_list(args.json_list, args.database) + + if args.ratio: + # If ratio needs to be modified, call make_yolo_dataset in a temp dir + with tempfile.TemporaryDirectory() as tmpdir: + call_make_yolo_dataset(args.json_list, args.database, Path(tmpdir)/"yolo_dataset") + call_modify_unlabeled_proportion(Path(tmpdir)/"yolo_dataset", args.output_dir, args.ratio) + + else: + # Otherwise, save dataset to output_dir directly + call_make_yolo_dataset(args.json_list, args.database, args.output_dir) + + print(f"Dataset was saved to {args.output_dir}") + + +if __name__ == "__main__": + _main() diff --git a/2d_lesion_detection/data_utils.py b/2d_lesion_detection/data_utils.py new file mode 100644 index 0000000..e2e1d9e --- /dev/null +++ b/2d_lesion_detection/data_utils.py @@ -0,0 +1,211 @@ +""" +Functions for data preprocessing, to generate a YOLO format dataset from a BIDS database. +""" + +import logging +import os +from pathlib import Path + +import cv2 +import torch +import numpy as np +import nibabel as nib +import scipy.ndimage as ndi +from torchvision.ops import masks_to_boxes +from skimage.exposure import equalize_adapthist + + +def nifti_to_png(nifti_path:Path, output_dir:Path, spinal_cord_path:Path=None, slice_list:list=None): + """ + Converts a nifti volume into slices along the sagittal plane and + saves them as png files in specified output_dir + + If spinal_cord_path is given, only slices that contain part of the spinal cord are saved and slice_list is ignored. + If slice_list is given and spinal_cord_path is None, only slices in the given list are saved. slice_list should contain ints. 
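+ For example, slice_list=[0, 2] saves only slices 0 and 2 of the volume.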
+ + Not suitable for segmentations (instead use nifti_seg_to_png) because + of intensity normalization between 0 and 255 + + png images are named _ + For example, if nifti file sub-cal056_ses-M12_STIR.nii.gz is given, + the first slice will be named sub-cal056_ses-M12_STIR_0.png + + Args: + nifti_path (pathlib.Path) : path to nifti file + output_dir (pathlib.Path) : path to the directory where png slices will be saved + spinal_cord_path (pathlib.Path) : path to the spinal cord segmentation file (optional) + + adapted from https://neuraldatascience.io/8-mri/nifti.html#plot-a-series-of-slices + """ + # Make output directory if it doesn't already exist + output_dir.mkdir(parents=True, exist_ok=True) + + filename = nifti_path.stem[:-len(".nii")] + + volume = nib.load(nifti_path) + vol_data = volume.get_fdata() # np array + + if spinal_cord_path is None: + sc_seg_data = None + else: + sc_seg = nib.load(spinal_cord_path) + sc_seg_data = sc_seg.get_fdata() + + # Normalize pixel intensity from 0 to 255 + vol_data = (vol_data - np.amin(vol_data)) * (255 / (np.amax(vol_data) - np.amin(vol_data))) + vol_data = np.round(vol_data) + + n_slice = vol_data.shape[2] + + for i in range(n_slice): + # if spinal cord segmentation is given, check if slice contains spinal cord + if not sc_seg_data is None: + sc_seg_slice = sc_seg_data[:, :, i] + + if sc_seg_slice.max() == 1: + # slice contains spinal cord + output_path = os.path.join(str(output_dir), f"{filename}_{i}.png") + img_slice = np.clip(ndi.rotate(vol_data[:, :, i], 90) / 255.0, 0, 1) # make sure it's between 0 and 1 for histogram equalization + + cv2.imwrite(output_path, equalize_adapthist(img_slice)*255) # equalize histogram + else: + assert(sc_seg_slice.max() == 0) + else: + # if no segmentation is given + # save slice if slice is in slice_list OR if slice_list is None + if slice_list is None or i in slice_list: + output_path = os.path.join(str(output_dir), f"{filename}_{i}.png") + cv2.imwrite(output_path, ndi.rotate(vol_data[:, :, i], 90)) + + +def mask_to_bbox(mask:np.ndarray) -> "np.ndarray|None": + """ + Extracts bounding box coordinates for each object in a binary tensor + + Bounding box coordinates are in format (x_center, y_center, width, height) + with normalized cooordinates (between 0 and 1) + + Args: + mask (np.ndarray): binary mask to get bboxes from + + Returns: + boxes_array (np.ndarray|None): array containing bounding box coordinates for each object + if no object is detected, None is returned + + """ + # Check if we have a binary mask + try: + assert np.all(np.logical_or(mask == 0, mask == 1)) + except AssertionError as e: + logging.warning(f"{e}: A binary mask is expected, but given mask is not.") + + width = mask.shape[1] + height = mask.shape[0] + + # Separate each object + labeled_array, num_labels = ndi.label(mask) + + if num_labels == 0: # No objects + return None + + # List to store bounding boxes for each lesion + boxes = [] + + # Loop over each labeled region + for label in range(1, num_labels + 1): + # Create a boolean mask for the current object + obj_mask = labeled_array == label + + # Compute the bounding box for the current object + # Returns format (x1, y1, x2, y2) in pixels + obj_box = masks_to_boxes(torch.from_numpy(obj_mask).unsqueeze(0))[0] + + # Add to list + boxes.append(obj_box) + + # Convert list of tensors to an array + boxes_array = np.array([box.numpy() for box in boxes]) + + # Convert coordinates format to (x_center, y_center, width, height) and normalize + boxes_array = convert_bboxes_format(boxes_array, 
width, height) + + return boxes_array + + +def convert_bboxes_format(bboxes:np.ndarray, img_width:int, img_height:int) -> np.ndarray: + """ + Converts bounding box format from (x1, y1, x2, y2) in pixels to + (x_center, y_center, width, height) normalized between 0 and 1 + + Args: + bboxes (np.ndarray): Bounding box to convert + img_width (int): Corresponding image width (px) + img_height (int): Corresponding image height (px) + + Returns: + (np.ndarray): converted coordinates + """ + # Extract coordinates from the input array + x_start, y_start, x_end, y_end = bboxes[:,0], bboxes[:,1], bboxes[:,2], bboxes[:,3] + + # Calculate center coordinates + x_center = (x_start + x_end) / 2 + y_center = (y_start + y_end) / 2 + + # Calculate width and height of boxes + width = x_end - x_start + height = y_end - y_start + + # Normalize + x_center = x_center/img_width + width = width/img_width + y_center = y_center/img_height + height = height/img_height + + # Stack the converted coordinates and sizes into a single array + return np.stack((x_center, y_center, width, height), axis=-1) + + +def labels_from_nifti(nifti_labels_path:Path, output_dir:Path)->Path: + """ + Creates txt files containing bounding box coordinates for each slice in a nifti segmentation + If no bounding box is found for a given slice, no txt file is created + + txt filenames correspond to the nifti name: + For example, if sub-cal056_ses-M12_STIR_lesion-manual.nii.gz is given as input, the txt file for the + first slice will be named sub-cal056_ses-M12_STIR_0.txt + + Args: + nifti_labels_path (pathlib.Path): Path to the nifti file containing the segmentation + output_dir (pathlib.Path): Path to the directory where txt files will be saved + """ + # Make output directory if it doesn't already exist + output_dir.mkdir(parents=True, exist_ok=True) + + filename = nifti_labels_path.stem[:-len("_lesion_manual.nii")] + + # get nifti volume as numpy array + volume = nib.load(nifti_labels_path) + vol_array = volume.get_fdata() + + # For each slice, extract bounding boxes and save to txt file + n_slice = vol_array.shape[2] + for i in range(n_slice): + slice_array = np.round(ndi.rotate(vol_array[:, :, i], 90)) + + # Get bounding boxes + boxes_array = mask_to_bbox(slice_array) + + if not boxes_array is None: + # Add a column for class 0 + # Since we only have one type of object to detect (lesion) + column_of_zeros = np.zeros((boxes_array.shape[0], 1)) + boxes_array = np.concatenate((column_of_zeros, boxes_array), axis=1) + + # Save to output directory + output_path = os.path.join(str(output_dir), filename + f"_{i}.txt") + np.savetxt(output_path, boxes_array, fmt = ['%g', '%.6f', '%.6f', '%.6f', '%.6f']) + + else: + # If no object is found, no txt file is generated + pass \ No newline at end of file diff --git a/2d_lesion_detection/default_train_params.json b/2d_lesion_detection/default_train_params.json new file mode 100644 index 0000000..1fbc041 --- /dev/null +++ b/2d_lesion_detection/default_train_params.json @@ -0,0 +1,16 @@ + +{ + "epochs": 150, + "lr0": 0.09, + "lrf": 0.08, + "box": 15.6, + "cls": 4.1, + "mosaic": 0, + "hsv_s": 0, + "hsv_h": 0, + "hsv_v": 0.45, + "degrees": 10, + "scale": 0.5, + "fliplr": 0.25, + "translate": 0.25 +} diff --git a/2d_lesion_detection/default_tune_params.json b/2d_lesion_detection/default_tune_params.json new file mode 100644 index 0000000..35a26f1 --- /dev/null +++ b/2d_lesion_detection/default_tune_params.json @@ -0,0 +1,15 @@ +{ + "lr0": [0.0001, 0.5], + "lrf": [0.01, 1.0], + "degrees": [0, 20.0], + "scale": 
[0.5, 0.9], + "fliplr": [0.0, 1.0], + "translate": [0.0, 0.9], + "hsv_v": [0.0, 0.9], + "box": [0.02, 0.2], + "cls": [0.2, 4.0], + "mosaic": 0, + "hsv_h": 0, + "hsv_s": 0, + "iterations": 16 +} \ No newline at end of file diff --git a/2d_lesion_detection/make_yolo_dataset.py b/2d_lesion_detection/make_yolo_dataset.py new file mode 100644 index 0000000..9e9b6cc --- /dev/null +++ b/2d_lesion_detection/make_yolo_dataset.py @@ -0,0 +1,174 @@ +""" +Generates a YOLO format dataset from a BIDS database. +Only processes scans in the given json list. + +Here is how the json list should be formatted: + {"train": ["sub-cal056_ses-M12_STIR", + "sub-edm011_ses-M0_PSIR", + "sub-cal072_ses-M0_STIR"], + "val": ["sub-cal157_ses-M12_STIR"], + "test": ["sub-edm076_ses-M0_PSIR"]} + + +The YOLO dataset is formatted as follows: + dataset/ + │ + ├── images/ + │ ├── train/ + │ │ ├── sub-cal056_ses-M12_STIR_0.png + │ │ ├── sub-cal056_ses-M12_STIR_1.png + │ │ └── ... + │ ├── val/ + │ │ ├── sub-tor006_ses-M12_PSIR_0.png + │ │ └── ... + │ └── test/ + │ ├── sub-tor007_ses-M12_PSIR_0.png + │ └── ... + │ + ├── labels/ + │ ├── train/ + │ │ ├── sub-cal056_ses-M12_STIR_0.txt + │ │ ├── sub-cal056_ses-M12_STIR_1.txt + │ │ └── ... + │ ├── val/ + │ │ ├── sub-tor006_ses-M12_PSIR_0.txt + │ │ └── ... + │ └── test/ + │ ├── sub-tor007_ses-M12_PSIR_0.txt + │ └── ... + │ + └── data.yaml + +""" +import logging +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from pathlib import Path + +import glob +import json +import ruamel.yaml + +from data_utils import nifti_to_png, labels_from_nifti + +logging.basicConfig(filename='pre_process_warning.log', level=logging.WARNING, filemode='w') + +def process_scan(nii_volume:str, output_dir:Path, database:Path, set_name:str): + """ + For every slice in input scan: + - Extracts bounding boxes from label file + - Saves bounding box coordinates in .txt file + - Saves image slice as .png file + + Output files are saved in the specified set's folder (test, train or val) within the output_dir. 
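+ For example, slice 0 of sub-cal056_ses-M12_STIR in the train set is saved as images/train/sub-cal056_ses-M12_STIR_0.png, with its bounding boxes (if any) in labels/train/sub-cal056_ses-M12_STIR_0.txt.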
+ + Args: + nii_volume (str): Name of the scan, including session number and contrast type + example format --> sub-cal056_ses-M12_STIR + output_dir (Path): Path to the YOLO dataset folder + database (Path): Path to the BIDS database + set_name (str): Dataset name -- must be one of ["train", "test", "val"] + """ + + patient = nii_volume.split("_")[0] + ses = nii_volume.split("_")[1] + + print(f"Processing scan {nii_volume}") + + # 1- get bounding boxes from segmentation and save to txt file + # Check if scan has already been processed + label_pattern = output_dir / "labels"/ set_name/(nii_volume+"*") + matching_slices = glob.glob(str(label_pattern)) + + if matching_slices == []: # if no slices are found, process scan + lesion_nii_path = database/ "derivatives"/ "labels"/ patient/ ses/ "anat"/ (nii_volume+"_lesion-manual.nii.gz") + labels_from_nifti(lesion_nii_path, output_dir / "labels"/ set_name) + + # 2- save spinal cord slices as pngs + # Check if scan has already been processed + img_pattern = output_dir / "images"/ set_name/(nii_volume+"*") + matching_slices = glob.glob(str(img_pattern)) + + if matching_slices == []: # if no slices are found, process scan + image_nii_path = database/ patient/ ses/ "anat"/ (nii_volume+".nii.gz") + spinal_cord_nii_path = database/ "derivatives"/ "labels"/ patient/ ses/ "anat"/ (nii_volume+"_seg-manual.nii.gz") + nifti_to_png(image_nii_path, output_dir / "images"/ set_name, spinal_cord_nii_path) + + # Check that all txt files have a corresponding png + # This was implemented after noticing that some sc segmentations were blank + list_of_img_slices = [Path(file).name.replace(".png","") for file in glob.glob(str(img_pattern))] + + for file in glob.glob(str(label_pattern)): + filename = Path(file).name.replace(".txt","") + + try: + assert filename in list_of_img_slices + except AssertionError as e: + # If a txt file has no corresponding png, all slices from that volume are saved + # and a warning is logged + logging.warning(f"{e}: {filename} has a segmentation file, but no corresponding image. 
Saving all slices") + parts = filename.split('_') + nii_name = '_'.join(parts[:-1]) + + image_nii_path = database/ filename.split("_")[0]/ ses/ "anat"/ (nii_name+".nii.gz") + nifti_to_png(image_nii_path, output_dir / "images"/ set_name) #save all slices + + +def _main(): + parser = ArgumentParser( + prog = 'make_yolo_dataset', + description = 'Generates YOLO format dataset from a list of scans and a BIDS database.', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-j', '--json-list', + required = True, + type = Path, + help = 'Path to json list of images to process') + parser.add_argument('-d', '--database', + required = True, + type = Path, + help = 'Path to BIDS database') + parser.add_argument('-o', '--output-dir', + required = True, + type = Path, + help = 'Output directory for YOLO dataset') + + + args = parser.parse_args() + dir_name = args.output_dir.name + + with open(args.json_list, "r") as json_file: + data = json.load(json_file) + + training_list = data["train"] + validation_list = data["val"] + test_list = data["test"] + + for volume in training_list: + process_scan(volume, args.output_dir, args.database, "train") + + for volume in validation_list: + process_scan(volume, args.output_dir, args.database, "val") + + for volume in test_list: + process_scan(volume, args.output_dir, args.database, "test") + + + # Create yml file + yml_str = f"""\ + path: "{dir_name}" + train: "images/train" + val: "images/val" + test: "images/test" + + nc: 1 + names: ["lesion"] + """ + + yaml = ruamel.yaml.YAML(pure=True) + yaml.preserve_quotes = True + data = yaml.load(yml_str) + + yaml.dump(data, args.output_dir/(dir_name+".yml")) + + +if __name__ == "__main__": + _main() diff --git a/2d_lesion_detection/modify_unlabeled_proportion.py b/2d_lesion_detection/modify_unlabeled_proportion.py new file mode 100644 index 0000000..50d84a3 --- /dev/null +++ b/2d_lesion_detection/modify_unlabeled_proportion.py @@ -0,0 +1,87 @@ +""" +Script for changing the proportion of unlabeled slices (i.e. 
slices that contain no lesion) +in the train set from a yolo dataset +""" +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import os +from pathlib import Path +import random +import shutil +import ruamel.yaml + + +def _main(): + parser = ArgumentParser( + prog = 'modify_unlabeled_proportion', + description = 'From a yolo dataset, change the proportion of slices that ' + 'contain no lesion (unlabeled) in the train set', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-i', '--input-path', + required = True, + type = Path, + help = 'Path to existing YOLO database') + parser.add_argument('-o', '--output-path', + required = True, + type = Path, + help = 'Path to new YOLO database') + parser.add_argument('-r', '--ratio', + default = 0.25, + type = float, + help = 'Proportion of dataset that should be unlabeled') + + args = parser.parse_args() + + all_train = os.listdir(args.input_path/"images"/"train") + all_train = [file for file in all_train if not file.startswith(".")] #remove hidden files + all_train = [filename.replace(".png", "") for filename in all_train] #remove extension + + all_labels = os.listdir(args.input_path/"labels"/"train") + all_labels = [file for file in all_labels if not file.startswith(".")] #remove hidden files + all_labels = [filename.replace(".txt", "") for filename in all_labels] #remove extension + + all_unlabelled = [filename for filename in all_train if filename not in all_labels] + + random.seed(10) + n_unlabelled_to_copy = (len(all_labels)*args.ratio)/(1-args.ratio) + unlabelled_to_copy = random.sample(all_unlabelled, int(n_unlabelled_to_copy)) + + # Transfer files to new dataset folder + shutil.copytree(args.input_path/"labels", args.output_path/"labels") # all labels + shutil.copytree(args.input_path/"images"/"test", args.output_path/"images"/"test") # images in test + shutil.copytree(args.input_path/"images"/"val", args.output_path/"images"/"val") # images in val + + os.makedirs(args.output_path/"images"/"train", exist_ok=True) + # copy over the unlabelled images that have been selected + for file in unlabelled_to_copy: + print(f"Copying over {file}") + source_file_path = args.input_path/"images"/"train"/(file + ".png") + destination_file_path = args.output_path/"images"/"train"/(file + ".png") + shutil.copy(source_file_path, destination_file_path) + + # copy over all images that are labelled + for file in all_labels: + print(f"Copying over {file}") + source_file_path = args.input_path/"images"/"train"/(file + ".png") + destination_file_path = args.output_path/"images"/"train"/(file + ".png") + shutil.copy(source_file_path, destination_file_path) + + + # create new yaml file + yml_str = f"""\ + path: "{args.output_path.name}" + train: "images/train" + val: "images/val" + test: "images/test" + + nc: 1 + names: ["lesion"] + """ + + yaml = ruamel.yaml.YAML(pure=True) + yaml.preserve_quotes = True + data = yaml.load(yml_str) + + yaml.dump(data, args.output_path/(args.output_path.name+".yml")) + +if __name__ == "__main__": + _main() diff --git a/2d_lesion_detection/requirements.txt b/2d_lesion_detection/requirements.txt new file mode 100644 index 0000000..64157ea --- /dev/null +++ b/2d_lesion_detection/requirements.txt @@ -0,0 +1,106 @@ +appnope @ file:///home/conda/feedstock_root/build_artifacts/appnope_1707233003401/work +asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work +attrs==23.2.0 +certifi @ 
file:///private/var/folders/c_/qfmhj66j0tn016nkx_th4hxm0000gp/T/abs_dfez5lkpj0/croot/certifi_1707229180975/work/certifi +charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work +clearml==1.14.2 +comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1704278392174/work +contourpy==1.2.0 +cycler==0.12.1 +debugpy @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/debugpy_1699267934478/work +decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work +defusedxml==0.7.1 +exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work +executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work +filelock @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/filelock_1701807880662/work +fonttools==4.48.1 +fsspec==2024.2.0 +furl==2.1.3 +idna @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/idna_1699239456905/work +imageio==2.33.1 +importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1703269254275/work +iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work +ipykernel @ file:///Users/runner/miniforge3/conda-bld/ipykernel_1707326353119/work +ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1706795662110/work +jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work +Jinja2 @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/jinja2_1707339417889/work +joblib @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/joblib_1699246281347/work +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1699283905679/work +jupyter_core @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/jupyter_core_1701805380453/work +kiwisolver==1.4.5 +lazy_loader==0.3 +MarkupSafe @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/markupsafe_1707339363573/work +matplotlib==3.8.2 +matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work +med2image==2.6.6 +mkl-fft @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/mkl_fft_1699240945874/work +mkl-random @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/mkl_random_1699239515205/work +mkl-service==2.4.0 +monai==1.3.0 +mpmath @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/mpmath_1699248969763/work +nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work +networkx @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/networkx_1699249062431/work +nibabel==5.2.0 +numpy @ file:///private/var/folders/c_/qfmhj66j0tn016nkx_th4hxm0000gp/T/abs_b7iptlxgej/croot/numpy_and_numpy_base_1708638622773/work/dist/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl#sha256=94ff619271f410d5b23b2f118daf4420d22b48b21c941819693d61b19667808d +opencv-python==4.9.0 +opencv-python-headless==4.9.0.80 +orderedmultidict==1.0.1 +packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work +pandas @ file:///Users/runner/miniforge3/conda-bld/pandas_1705728473605/work +parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work +pathlib2==2.3.7.post1 +pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work +pfmisc==2.2.14 +pickleshare @ 
file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work +pillow @ file:///Users/runner/miniforge3/conda-bld/pillow_1704252132396/work +platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work +pluggy @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/pluggy_1699238591533/work +prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work +psutil @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/psutil_1699246686105/work +ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +pudb==2024.1 +pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work +py-cpuinfo==9.0.0 +pydicom==2.4.4 +Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work +PyJWT==2.8.0 +pyparsing==3.1.1 +pytest @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/pytest_1699239314576/work +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work +pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work +PyYAML @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/pyyaml_1699244222906/work +pyzmq @ file:///private/var/folders/c_/qfmhj66j0tn016nkx_th4hxm0000gp/T/abs_c7mhasdker/croot/pyzmq_1705605099593/work +referencing==0.33.0 +requests @ file:///private/var/folders/sy/f16zz6x50xz3113nwtb9bvq00000gp/T/abs_4fp_3tflo8/croot/requests_1707355574897/work +rpds-py==0.17.1 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +scikit-image==0.22.0 +scikit-learn @ file:///Users/runner/miniforge3/conda-bld/scikit-learn_1708076412806/work +scipy @ file:///private/var/folders/sy/f16zz6x50xz3113nwtb9bvq00000gp/T/abs_016_yoiwt8/croot/scipy_1710947331018/work/dist/scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl#sha256=91cbccebfd01ecb647349dbd83a211ec85c8cbd21d783ffc23c2d548a778177d +seaborn==0.13.2 +setuptools==68.2.2 +six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work +stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work +supervision==0.18.0 +sympy @ file:///private/var/folders/sy/f16zz6x50xz3113nwtb9bvq00000gp/T/abs_94zd__m8ro/croot/sympy_1701397645431/work +thop==0.1.1.post2209072238 +threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work +tifffile==2024.1.30 +torch==2.2.1 +torchaudio==2.2.1 +torchvision==0.17.1 +tornado @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/tornado_1699250848162/work +tqdm==4.66.2 +traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1704212992681/work +typing_extensions @ file:///private/var/folders/sy/f16zz6x50xz3113nwtb9bvq00000gp/T/abs_4eenipoqj8/croot/typing_extensions_1705619919539/work +tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1703878702368/work +ultralytics==8.1.11 +urllib3 @ file:///Users/builder/cbouss/perseverance-python-buildout/croot/urllib3_1701808673422/work +urwid==2.5.2 +urwid_readline==0.13 +wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work +wheel==0.41.2 +zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work diff --git a/2d_lesion_detection/sc_seg_from_list.py b/2d_lesion_detection/sc_seg_from_list.py new file mode 100644 index 0000000..b125151 --- /dev/null +++ 
b/2d_lesion_detection/sc_seg_from_list.py @@ -0,0 +1,115 @@ +""" +Tool for generating missing spinal cord segmentations for a list of scans. +For every scan in the given list, checks if a sc segmentation file exists in the BIDS database. +If the file is missing, generates it using the spinal cord toolbox. + +Make sure to activate the sct env before running: +source ~/sct_6.2/python/envs/venv_sct/bin/activate +""" + +from pathlib import Path +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import subprocess +import logging + +from typing import List, Dict +import json + +logging.basicConfig(filename='sc_seg_warning.log', level=logging.WARNING, filemode='w') + +def segment_spinal_cord(image_path:str, labels_path:str, seg_name:str): + """ + Uses the spinal cord toolbox to generate a spinal cord segmentation file. + + Args: + image_path (str): Path to the image containing the spinal cord + labels_path (str): Path to the corresponding labels folder. + This is where the segmentation file will be saved. + seg_name (str): Name of the segmentation file + """ + + command = [ + "sct_deepseg_sc", + "-i", image_path, + "-c", "t2", + "-ofolder", labels_path, + "-o", seg_name + ] + subprocess.run(command, check=True) + +def remove_errors_from_json(json_data:Dict[str, str], remove_list:List[str]): + """ + Removes volumes that are in remove_list from json_data + + Args: + json_data (Dict[str, str]): Dictionary of volume names in each dataset (contents of json list) + remove_list (List[str]): List of volume names to remove from json_data + + Returns: + json_data (Dict[str, str]): Modified dictionary of volume names in each dataset + without the volumes in remove_list + """ + for key, value in json_data.items(): + # Check if the current value is a list + if isinstance(value, list): + # Remove names from the list if they exist + json_data[key] = [item for item in value if item not in remove_list] + return json_data + +def main(): + parser = ArgumentParser( + prog = 'sc_seg_from_list', + description = 'Generate missing spinal cord segmentations for the given list of scans.', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-j', '--json-list', + required = True, + type = Path, + help = 'path to json list of scans to process') + parser.add_argument('-d', '--database', + required = True, + type = Path, + help = 'path to BIDS database') + + args = parser.parse_args() + + with open(args.json_list, "r") as json_file: + data = json.load(json_file) + + scan_list = [item for sublist in data.values() for item in sublist] + + error_list = [] + for scan in scan_list: + patient = scan.split("_")[0] + ses = scan.split("_")[1] + + spinal_seg_nii_path = args.database/ "derivatives"/ "labels"/ patient/ ses/ "anat"/ (scan+"_seg-manual.nii.gz") + + if not spinal_seg_nii_path.exists(): + image_path = args.database/ patient/ ses/ "anat"/ (scan+".nii.gz") + labels_path = spinal_seg_nii_path.parent + seg_name = spinal_seg_nii_path.name + + # When the segmentation fails, add name to error_list and skip it + try: + segment_spinal_cord(image_path, labels_path, seg_name) + except Exception as e: + logging.warning(f"Error processing scan {scan}: {e}") + error_list.append(scan) + + # if error_list isn't empty + if error_list: + # If there are errors, create new json without the filenames that had errors + data = remove_errors_from_json(data, error_list) + output_file_path = args.json_list.parent/(args.json_list.stem +"_with_sc_seg.json") + with open(output_file_path, "w") as output_file: + 
json.dump(data, output_file, indent=4) + + print(f"Done! Here are the scans that couldn't be processed: {error_list}. See sc_seg_warning.log for more info" + f"\n\nAn updated json list that excludes the error files has been saved: {output_file_path}") + + else: + print("Done! There were no errors") + + +if __name__ == "__main__": + main() diff --git a/2d_lesion_detection/tests/__init__.py b/2d_lesion_detection/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/2d_lesion_detection/tests/test_data_utils.py b/2d_lesion_detection/tests/test_data_utils.py new file mode 100644 index 0000000..3081edb --- /dev/null +++ b/2d_lesion_detection/tests/test_data_utils.py @@ -0,0 +1,109 @@ +""" +Unit tests for functions used in data_utils.py +""" +import os +import logging +import tempfile +from pathlib import Path + +import numpy as np +import nibabel as nib + +from data_utils import nifti_to_png, mask_to_bbox, convert_bboxes_format + +def test_nifti_to_png(): + """ + Makes sure that only the correct slices are saved, but does not check the content of the saved pngs + """ + # Create nifti image + image_data = np.zeros((3, 3, 3)) + image_data[:,:,0] = image_data[:,:,0] * 500 + image_data[:,:,1] = image_data[:,:,1] * 1000 + + # Create spinal cord seg mask + # Adding white pixel to slices 0 and 2 + sc_data = np.zeros((3, 3, 3)) + sc_data[1,1,0] = 1 + sc_data[1,1,2] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + # Save nifti + image_path = Path(tmpdir)/'image.nii.gz' + nifti_image = nib.Nifti1Image(image_data, affine=np.eye(4)) + nib.save(nifti_image, image_path) + + sc_seg_path = Path(tmpdir)/'sc_seg.nii.gz' + nifti_sc_seg = nib.Nifti1Image(sc_data, affine=np.eye(4)) + nib.save(nifti_sc_seg, sc_seg_path) + + # With sc_seg -- saved slices should be 0 and 2 + nifti_to_png(image_path, Path(tmpdir)/'with_sc_seg', sc_seg_path) + txt_names = os.listdir(Path(tmpdir)/'with_sc_seg') + + assert len(txt_names) == 2 + assert "image_0.png" in txt_names + assert "image_2.png" in txt_names + assert "image_1.png" not in txt_names + + # With list of slices to save + nifti_to_png(image_path, Path(tmpdir)/'with_slice_list', slice_list=[1,2]) + txt_names = os.listdir(Path(tmpdir)/'with_slice_list') + + assert len(txt_names) == 2 + assert "image_1.png" in txt_names + assert "image_2.png" in txt_names + assert "image_0.png" not in txt_names + + +def test_mask_to_bbox_warning(caplog): + """ + Give a non binary mask to mask_to_bbox(), a warning should be logged + """ + mask = np.zeros((3, 3)) + mask[1,:] = 3 + + caplog.set_level(logging.WARNING) + mask_to_bbox(mask) + + # Make sure a warning was logged + assert len(caplog.records) == 1 + + # Check content of the warning + warning = caplog.records[0] + assert "A binary mask is expected" in warning.message + + +def test_mask_to_bbox(): + """ + Give a mask to mask_to_bbox() and check that bbox is correct + """ + mask = np.array([[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 1, 1, 0, 1, 0], + [0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]) + + expected = np.array([[1.5, 3, 1, 2],[4, 3, 0, 0]])/6 + assert np.allclose(mask_to_bbox(mask), expected, atol=0.001) + + +def test_convert_bboxes_format(): + """ + Convert a bounding box in x1,y1,x2,y2 format to + x_center,y_center,width,height normalized + """ + xyxy = np.array([[1, 1, 4, 3]]) + xywh = np.array([[2.5, 2, 3, 2]]) + + width = 10 + height = 5 + + xywhn = np.array([[0, 0, 0, 0]], dtype=np.float64) + xywhn[0][0] = xywh[0][0]/width + xywhn[0][1] = xywh[0][1]/height + xywhn[0][2] = 
xywh[0][2]/width + xywhn[0][3] = xywh[0][3]/height + + assert np.array_equal(convert_bboxes_format(xyxy, width, height), xywhn) + diff --git a/2d_lesion_detection/tests/test_validation.py b/2d_lesion_detection/tests/test_validation.py new file mode 100644 index 0000000..372888b --- /dev/null +++ b/2d_lesion_detection/tests/test_validation.py @@ -0,0 +1,124 @@ +""" +Unit tests for functions used in validation.py +""" + +import torch +from torch import tensor + +from validation import confusion_matrix, xywhn_to_xyxy, get_png_from_txt, merge_overlapping_boxes, intersection_over_smallest_area + +def test_merge_overlapping_boxes(): + """ + Define two boxes and merge them with two different threshold values. + Make sure merged boxes are as expected. + """ + # Define two overlapping boxes (x1,y1,x2,y2) + box1= tensor([3,3,2,2,6,6], dtype=torch.int32) + box2= tensor([4,5,3,3,6,7], dtype=torch.int32) + + boxes = torch.stack((box1,box2)) + + # Get merged boxes with two different IoUs + merged_25 = merge_overlapping_boxes(boxes, 0.25) + merged_75 = merge_overlapping_boxes(boxes, 0.80) + + # Boxes should be merged with IoU 25 + # but not with IoU 50 + assert torch.equal(merged_25[0], tensor([3,5,2,2,6,7], dtype=torch.int32)) + assert torch.equal(merged_75[0], boxes[0]) + assert torch.equal(merged_75[1], boxes[1]) + + +def test_intersection_over_smallest_area(): + """ + Define 3 boxes and make sure intersection over smallest area values are + as expected. + """ + # Define overlapping boxes (s0,sf,x1,y1,x2,y2) + # s0 and sf are irrelevant, iosa is calculated in 2d + box1= tensor([3,3,2,2,6,6], dtype=torch.int32) + box2= tensor([3,3,3,3,6,7], dtype=torch.int32) + box3= tensor([3,3,6,2,8,8], dtype=torch.int32) + + iosa12 = intersection_over_smallest_area(box1, box2) + iosa13 = intersection_over_smallest_area(box1, box3) + + assert iosa12 == 0.75 + assert iosa13 == 0 + + +def test_confusion_matrix(): + """ + Define two lists of bounding boxes: ground truth (labels) and predictions (preds) + Calculate tp, fn, fp values with different thresholds and make sure values are as expected. 
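+ Boxes use the (s0, sf, x1, y1, x2, y2) format, where s0 and sf are the first and last slice indices. + """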
+ """ + labels= tensor([[2,2,1,1,5,6], [2,3,8,4,8,6]], dtype=torch.int32) + preds= tensor([[2,2,1,2,6,7], [2,3,8,3,9,5], [3,5,8,4,8,6]], dtype=torch.int32) + + # With 0.5 iou threshold + tp, fn, fp = confusion_matrix(labels, preds, 0.5) + assert tp == 1 + assert fn == 1 + assert fp == 2 + + # With 0.1 iou threshold + tp, fn, fp = confusion_matrix(labels, preds, 0.1) + assert tp == 3 + assert fn == 0 + assert fp == 0 + + # With 0.6 iou threshold + tp, fn, fp = confusion_matrix(labels, preds, 0.8) + assert tp == 0 + assert fn == 2 + assert fp == 3 + + +def test_xywhn_to_xyxy(): + """ + Define image size and corresponding coordinates in center format + Convert to corner format and make sure result is as expected + + Repeat for smaller boxes + """ + img_width = 10 + img_height = 20 + center_coords = torch.tensor([[4/img_width, # x_center + 7/img_height, # y_center + 4/img_width, # width + 6/img_height]]) # height + + corners = xywhn_to_xyxy(center_coords, img_width, img_height) + + assert torch.equal(corners, tensor([[2,4,6,10]])) + + img_height = 5 + img_width = 7 + center_coords = torch.tensor([[1.5/img_width, + 2/img_height, + 1/img_width, + 2/img_height]]) + + corners = xywhn_to_xyxy(center_coords, img_width, img_height) + + assert torch.equal(corners, tensor([[1,1,2,3]])) + + +def test_get_png_from_txt(): + """ + Define a labels folder path + Get corresponding images folder + + Repeat with a label txt file + The corresponding images path should be a png file + """ + # If input is a folder + label_folder = "~/data/yolo_training/dataset_1/labels/test" + + assert get_png_from_txt(label_folder) == "~/data/yolo_training/dataset_1/images/test" + + # If input is a file + txt_file = "~/data/yolo_training/dataset_1/labels/test/sub-cal080_ses-M0_STIR_2.txt" + + assert get_png_from_txt(txt_file) == "~/data/yolo_training/dataset_1/images/test/sub-cal080_ses-M0_STIR_2.png" + diff --git a/2d_lesion_detection/train_test_val_from_BIDS.py b/2d_lesion_detection/train_test_val_from_BIDS.py new file mode 100644 index 0000000..0df87c8 --- /dev/null +++ b/2d_lesion_detection/train_test_val_from_BIDS.py @@ -0,0 +1,93 @@ +""" +Makes a list of all the scans in BIDS database that have a lesion segmentation file (*lesion-manual.nii.gz), +Splits that list into 3 datasets (train, test and val). +Saves a json file containing the lists of test, train and val datasets. + +Here is an example of the json format: + {"train": ["sub-cal056_ses-M12_STIR", + "sub-edm011_ses-M0_PSIR", + "sub-cal072_ses-M0_STIR"], + "val": ["sub-cal157_ses-M12_STIR"], + "test": ["sub-edm076_ses-M0_PSIR"]} +""" + +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import glob +import json +import os +from pathlib import Path +from sklearn.model_selection import train_test_split + + +def _train_test_val_split(list_to_split:list, train_size:float, val_size:float): + """ + Splits a list into 3 lists (test, train and val) using the specified ratios. 
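+ For example, with train_size=0.8 and val_size=0.1, the list is first split 80/20 into train and test_val, and test_val is then split in half (test_size = (1 - 0.8 - 0.1)/(1 - 0.8) = 0.5) into val and test.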
+ + Args: + list_to_split (list): List of scan names to split into datasets + train_size (float): Proportion of scans for training (between 0 and 1) + val_size (float): Proportion of scans for validation (between 0 and 1) + + Returns: + train (list): List of scans in training set + test (list): List of scans in testing set + val (list): List of scans in validation set + """ + + train, test_val= train_test_split(list_to_split,random_state=0, test_size= round(1-train_size, 3)) + val, test= train_test_split(test_val,random_state=0, test_size= (round((1-train_size - val_size)/(1-train_size), 3))) + + return train, test, val + + +def _main(): + parser = ArgumentParser( + prog = 'train_test_val_from_BIDS', + description = 'Saves a json file containing the lists of test, train and val datasets', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-d', '--database', + required = True, + type = Path, + help = 'Path to BIDS database') + parser.add_argument('-o', '--output-path', + required = True, + type = Path, + help = 'Output path to json file where lists will be saved. Must end with .json') + parser.add_argument('-t', '--train-size', + default = 0.8, + type = float, + help = 'Proportion of dataset to use for training') + parser.add_argument('-v', '--val-size', + default = 0.1, + type = float, + help = 'Proportion of dataset to use for validation') + + args = parser.parse_args() + + # Find *_lesion-manual.nii.gz files + search_pattern = os.path.join(args.database, "derivatives", "labels", "**", "**", "anat", "*_lesion-manual.nii.gz") + matching_files = glob.glob(search_pattern) + + # Get volume names + volume_list = [] + for path in matching_files: + volume_name = Path(path).stem[:-len("_lesion-manual.nii")] + volume_list.append(volume_name) + + # Split into train, test and val sets + train, test, val = _train_test_val_split(volume_list, args.train_size, args.val_size) + + data_dict = {"train": train, + "test": test, + "val": val} + + # Save to json file + output_dir = args.output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(args.output_path, "w") as outfile: + json.dump(data_dict, outfile) + + +if __name__ == "__main__": + _main() \ No newline at end of file diff --git a/2d_lesion_detection/validation.py b/2d_lesion_detection/validation.py new file mode 100644 index 0000000..0f04688 --- /dev/null +++ b/2d_lesion_detection/validation.py @@ -0,0 +1,565 @@ +""" +Script for yolo model validation + +Takes ground truth bounding box labels and predicted labels, and computes recall and precision. +Numbers of TPs, FPs and FNs for every image, as well as recall and precision for the whole batch are saved to a csv file. +Also saves both ground truth and predicted bounding boxes as nifti images (where the contour of the bboxes is 1) +""" + +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import os +from pathlib import Path +from typing import List, Dict, Tuple, Optional +from PIL import Image +import pandas as pd +import torch +import numpy as np +import nibabel as nib + +IOSA = 0.2 # threshold for box merging, was determined by trying different values. With 0.2 boxes in a similar area + # are merged together. Because of the slice thickness, this threshold can't be too high because lesions can + # shift quite a bit from one slice to the next. 
+ + +def expand_bbox(box1:torch.Tensor, box2:torch.Tensor)-> torch.Tensor: + """ + Returns a single bounding box that contains both input boxes + + Args: + box1 (torch.Tensor): First bounding box + box2 (torch.Tensor): Second bounding box + format --> torch.tensor([x1, y1, x2, y2]) + + Returns: + expanded box (torch.Tensor): Bounding box containing both initial boxes + format --> torch.tensor([x1, y1, x2, y2]) + """ + # Expand box1 to include box2 + b1_s0, b1_sf, b1_x1, b1_y1, b1_x2, b1_y2 = box1 + b2_s0, b2_sf, b2_x1, b2_y1, b2_x2, b2_y2 = box2 + x1 = min(b1_x1, b2_x1) + x2 = max(b1_x2, b2_x2) + y1 = min(b1_y1, b2_y1) + y2 = max(b1_y2, b2_y2) + + s0 = min(b1_s0, b2_s0) + sf = max(b1_sf, b2_sf) + + return torch.Tensor([s0, sf, x1, y1, x2, y2]).int() + + +def intersection_over_smallest_area(boxA:torch.Tensor, boxB:torch.Tensor)-> float: + """" + Given two bounding boxes, calculates the intersection area over the smallest box's area + + Adapted from: https://gist.github.com/meyerjo/dd3533edc97c81258898f60d8978eddc + + Args: + boxA (torch.Tensor): First bounding box + boxB (torch.Tensor): Second bounding box + format --> torch.tensor([s0, sf, x1, y1, x2, y2]) + + Returns: + Intersection over small area (float) + """ + + # determine the (x, y)-coordinates of the intersection rectangle + x1 = max(boxA[2], boxB[2]) + y1 = max(boxA[3], boxB[3]) + x2 = min(boxA[4], boxB[4]) + y2 = min(boxA[5], boxB[5]) + + # compute the area of intersection rectangle + interArea = abs(max((x2 - x1, 0)) * max((y2 - y1), 0)) + + if interArea == 0: + return 0 + + # compute the area of both the prediction and ground-truth rectangles + boxAArea = abs((boxA[4] - boxA[2]) * (boxA[5] - boxA[3])) + boxBArea = abs((boxB[4] - boxB[2]) * (boxB[5] - boxB[3])) + + smallest_area = min(boxAArea, boxBArea) + + return interArea/smallest_area + + +def boxes_overlap_or_consecutive(box1:torch.Tensor, box2:torch.Tensor)->bool: + """ + Determines whether two boxes either overlap or are on consecutive slices. + + Args: + box1 (torch.Tensor): First box + box2 (torch.Tensor): Second box + they should both be formatted as: torch.tensor([s0, sf, x1, y1, x2, y2]) + + Returns: + True or False + """ + # Check if they overlap + if box2[0] <= box1[1] and box2[0] >= box1[0]: + return True + + elif box2[1] <= box1[1] and box2[1] >= box1[0]: + return True + + # Check if they are consecutive + elif box1[1] + 1 == box2[0] or box2[1] + 1 == box1[0]: + return True + + else: + return False + +def merge_overlapping_boxes(boxes:torch.Tensor, iosa_threshold:float)-> List[torch.Tensor]: + """ + Takes a tensor of bounding boxes and groups together the ones that overlap (more than given threshold) + + I chose to use intersection over smallest box area (which is the proportion of the smallest box contained + in the bigger box) as a threshold instead of IoU because this way, if a tiny box is fully within a big box, the tiny + one will for sure be merged (iosa will be 1). 
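+ For example, a 2x2 box fully inside a 10x10 box has an IoU of only 4/100 = 0.04, but an intersection over smallest area of 4/4 = 1.0, so it is merged at any reasonable threshold.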
+ + Args: + boxes (torch.Tensor): tensor containing the bounding boxes + format --> torch.tensor([[s0, sf, x1, y1, x2, y2],[s0, sf, x1, y1, x2, y2], ...]) + iosa_threshold (float): Intersection over smallest area threshold + Boxes with a higher iosa than this value will be merged together + + Returns: + merged_boxes (List[torch.Tensor]): List of merged bounding boxes + """ + + # List that will contain the final merged boxes + merged_boxes = [] + + for box in boxes: + i = 0 + while i < len(merged_boxes): + merged_box = merged_boxes[i] + + # Check if the slices are consecutive + if boxes_overlap_or_consecutive(box, merged_box): + iosa = intersection_over_smallest_area(box, merged_box) + if iosa > iosa_threshold: + # Expand the merged box with box + box = expand_bbox(merged_box, box) + + del merged_boxes[i] # this box will be replaced by the newly merged box + # Don't increment i since merged_boxes is staying the same length (one box is replaced) + else: + i += 1 + else: + i += 1 + merged_boxes.append(box.round().int()) + + return merged_boxes + +def xywhn_to_xyxy(bboxes: torch.Tensor, img_width: int, img_height: int) -> torch.Tensor: + """ + Converts bounding box format from (x_center, y_center, width, height) normalized by image size + to (x1, y1, x2, y2) in pixels. + + Args: + bboxes (torch.Tensor): Tensor of bounding boxes in (x_center, y_center, width, height) format + format --> torch.tensor([[x_center, y_center, width, height],[x_center, y_center, width, height], ...]) + img_width (int): Width of the corresponding image + img_height (int): Height of the corresponding image + + Returns: + (torch.Tensor): Converted bounding box coordinates + format --> torch.tensor([[x1, y1, x2, y2],[x1, y1, x2, y2], ...]) + """ + # Extract coordinates and sizes from the input tensor + x_center, y_center, width, height = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] + + # Denormalize coordinates and sizes + x_center = x_center * img_width + width = width * img_width + y_center = y_center * img_height + height = height * img_height + + # Calculate (x1, y1, x2, y2) coordinates + x1 = x_center - width / 2 + y1 = y_center - height / 2 + x2 = x_center + width / 2 + y2 = y_center + height / 2 + + # Stack the converted coordinates into a single tensor + return torch.stack((x1, y1, x2, y2), dim=-1) + + +def get_png_from_txt(txt_path:str)-> str: + """ + Within a YOLO databsase, finds png image (or folder) path associated + with a given label file (or folder) path + + Args: + txt_path (str): Path to txt file or labels folder + + Returns: + images (str): Path to associated png file or images folder + """ + images = txt_path.replace("labels", "images") + + # Check if path is a txt file + if images.endswith(".txt"): + return images.replace(".txt", ".png") + + # If not, assume it is a folder + else: + return images + + +def image_from_bboxes(bboxes: List[torch.Tensor], nii_data: np.ndarray)-> np.ndarray: + """ + Creates a 3d image from bounding box coordinates. Sides of boxes will be 1 and background 0. + Image will have the same dimensions as nii_data. + + Args: + bboxes (List[torch.Tensor]): List of bounding boxes + format --> [torch.tensor([[s0, sf, x1, y1, x2, y2]]), ...] 
+ nii_data (np.ndarray): Original nifti image + + Returns: + nii_data (np.ndarray): Image of the bounding boxes + """ + + # Create an empty image volume + nii_data.fill(0) + + # Set bounding box edges to 1 + for bbox in bboxes: + (s0, sf, x1, y1, x2, y2) = bbox.tolist() + + y1 = nii_data.shape[1] - y1 -1 + y2 = nii_data.shape[1] - y2 -1 + + if y1 >= nii_data.shape[1]-1: + y1 = nii_data.shape[1] - 2 + if x2 >= nii_data.shape[0]-1: + x2 = nii_data.shape[0] - 2 + + nii_data[x1-1:x2+2, y1+1, s0:sf+1] = 1 + nii_data[x1-1:x2+2, y2-1, s0:sf+1] = 1 + nii_data[x1-1, y2-1:y1+1, s0:sf+1] = 1 + nii_data[x2+1, y2-1:y1+1, s0:sf+1] = 1 + + return nii_data + + +def confusion_matrix(ground_truth:Optional[List[torch.Tensor]], + predictions:Optional[List[torch.Tensor]], + iou_threshold:float)-> Tuple[int, int, int]: + """ + Computes True positive, False negative and False positive values from a list of + ground truth bounding boxes and a list of predictions. + + A box from the ground truth list is considered a match with a prediction box if their iou is + above iou_threshold. + + Boxes should be formatted like this: s0, sf, x1, y1, x2, y2 (where s0 and sf are the first and last slices) + + Args: + ground_truth (List[torch.Tensor]): list of ground truth bounding boxes + predictions (List[torch.Tensor]): list of prediction bounding boxes + iou_threshold (float): Intersection over union threshold + Boxes with an iou above or equal to this value are considered a match + + Returns: + tp (int): number of true positives + gt boxes with an iou above or equal to iou_threshold with a prediction box + fn (int): number of false negatives + ground truth boxes that don't match with a prediction box + fp (int): number of false positives + prediction boxes that don't match with a ground truth box + """ + # Start by checking if either of the lists is None + if ground_truth is None: + if predictions is None: + return 0, 0, 0 # Both ground_truth and predictions are None, return 0s for tp, fn, fp + else: + return 0, 0, len(predictions) # Ground_truth is None, so all predictions are false positives (fp) + elif predictions is None: + return 0, len(ground_truth), 0 # Predictions are None, so all ground truth are false negatives (fn) + + + # Create matrix of ious between all ground truths (rows) and predictions (columns) + ious = [] + for gt_box in ground_truth: + iou = [] + for pred_box in predictions: + # Calculate intersection + xmin = max(gt_box[2], pred_box[2]) + ymin = max(gt_box[3], pred_box[3]) + zmin = max(gt_box[0], pred_box[0]) + + xmax = min(gt_box[4], pred_box[4]) + ymax = min(gt_box[5], pred_box[5]) + zmax = min(gt_box[1], pred_box[1]) + + # adding 1 to xmax, ymaz, zmax because if the box has a width of 1 px for example, xmax-xmin will be 0 (same for union) + intersection = max(0, xmax+1 - xmin) * max(0, ymax+1 - ymin) * max(0, zmax+1 - zmin) + + # Calculate union + gt_volume = (gt_box[4]+1 - gt_box[2]) * (gt_box[5]+1 - gt_box[3]) * (gt_box[1]+1 - gt_box[0]) + pred_volume = (pred_box[4]+1 - pred_box[2]) * (pred_box[5]+1 - pred_box[3]) * (pred_box[1]+1 - pred_box[0]) + union = gt_volume + pred_volume - intersection + + # Calculate IoU + iou.append(intersection / union if union > 0 else 0) + ious.append(iou) + + + # Count the number of tp, fn, fp + tp = 0 + fn = 0 + fp = 0 + for _, gt_iou in enumerate(ious): + if max(gt_iou) < iou_threshold: + # For a given ground truth, if the max iou is below threshold -> FN + fn+=1 + + for _, pred_iou in enumerate(zip(*ious)): + # If the max iou for a given prediction is over 
threshold -> TP + if max(pred_iou) >= iou_threshold: + tp+=1 + # If the max is below threshold -> FP + else: + fp+=1 + + return tp, fn, fp + + +def get_volume_boxes(txt_paths:List[str], yolo_img_folder:Path, iosa:float)-> Dict[str, List[torch.Tensor]]: + """ + From a list of txt file paths containing slice-wise bounding box coordinates, sorts + bbox coordinates by volume into a dictionary and merges overlapping boxes + + Converts format from x_center, y_center, width, height (normalized) + to x1, y1, x2, y2 (in pixels) + + Example of expected filenames -> sub-cal080_ses-M0_STIR_2 + where 2 is the slice number + + Args: + txt_paths (List(str)): List of txt file paths that contain the bounding box coordinates + yolo_img_folder (Path): Path to the yolo dataset folder containing the images that correspond to txt_paths + iosa (float): Intersection over smallest area threshold for two bboxes to be merged + + Returns: + labels_dict (Dict[str, List[torch.Tensor]]): dictionary containing bounding boxes for every volume + key is volume name (sub-cal080_ses-M0_STIR for example) + value is a Tensor containg bounding box coordinates + format -> torch.tensor([s0, sf, x1, y1, x2, y2], [s0, sf, x1, y1, x2, y2], ...) + """ + # Start by making a dictionary that groups slices of each volume together + labels_dict_unmerged = {} + for txt_path in txt_paths: + parts = Path(txt_path).name.split('_') + slice_no = parts[-1].replace(".txt","") + volume = '_'.join(parts[:-1]) + + # Get bbox coordinates as tensor + data = [] + with open(txt_path, 'r') as file: + for line in file: + line = line.strip().split() + line = [float(x) for x in line[1:]] # take line[1:] to ignore the class number + data.append(line) + boxes_tensor = torch.tensor(data) + + image_path = yolo_img_folder/f"{volume}_{slice_no}.png" # corresponding image in yolo dataset + img = Image.open(image_path) # Img dimensions needed for bbox format conversion + boxes_tensor = xywhn_to_xyxy(boxes_tensor, img.width, img.height).round().int() + + # Add slice number at the beginning of each row + slice_indices = torch.tensor([int(slice_no), int(slice_no)]).repeat(boxes_tensor.shape[0], 1) + boxes_tensor = torch.cat((slice_indices, boxes_tensor), dim=1) + + # Add to dict + if volume in labels_dict_unmerged: + labels_dict_unmerged[volume] = torch.cat((labels_dict_unmerged[volume], boxes_tensor), dim=0) + + else: + labels_dict_unmerged[volume] = boxes_tensor + + + # Merge overlapping boxes within a volume + labels_dict={} + for volume, boxes in labels_dict_unmerged.items(): + labels_dict[volume]= merge_overlapping_boxes(boxes, iosa) + + return labels_dict + + +def compute_metrics(volumes_list:List[str], + labels_dict:Dict[str, List[torch.Tensor]], + preds_dict:Dict[str, List[torch.Tensor]], + canproco_path:Path, + output_folder:Path, + iou_threshold:float)-> pd.DataFrame: + """ + Compute TP, FP and FN values between predictions and labels. 
+ Save nifti volumes of ground truth and prediction boxes to output folder + + Args: + volumes_list (List[str]): list of volume names to process + labels_dict (Dict[str, List[torch.Tensor]]): dictionary containing ground truth boxes per volume + preds_dict (Dict[str, List[torch.Tensor]]): dictionary containing prediction boxes per volume + canproco_path (Path): path to canproco database + output_folder (Path): path to output folder where nifti volumes of gt and prediction boxes are saved + iou_threshold (float): Intersection over union threshold for a label and prediction box to be considered a match + + Returns: + all_metrics_df (pd.DataFrame): Dataframe containing tp, fp, fn values for every volume + """ + df_columns = ['Volume', 'TP', 'FP', 'FN'] + all_metrics_df = pd.DataFrame(columns=df_columns) + for volume in volumes_list: + + ## 1- Process ground truth + if volume in labels_dict: + # Get original nifti from canproco + parts = volume.split('_') + nii_path = canproco_path/ parts[0]/ parts[1]/ "anat"/ (volume+".nii.gz") + nii_data = nib.load(str(nii_path)) + + # Save boxes as nifti + label_bboxes = labels_dict[volume] + + label_boxes_image = image_from_bboxes(label_bboxes, nii_data.get_fdata()) + boxes_nii = nib.Nifti1Image(label_boxes_image, nii_data.affine) # keep all metadata from original image + boxes_nii.header.set_data_shape(nii_data.shape) + + nib.save(boxes_nii, str(output_folder/ (volume +"_label.nii.gz"))) + + else: + # If no ground truth boxes exist for this volume + label_bboxes = None + + + ## 2- Process prediction + if volume in preds_dict: + # Get original nifti from canproco + parts = volume.split('_') + nii_path = canproco_path/ parts[0]/ parts[1]/ "anat"/ (volume+".nii.gz") + nii_data = nib.load(str(nii_path)) + + # Save boxes as nifti + pred_bboxes = preds_dict[volume] + + pred_boxes_image = image_from_bboxes(pred_bboxes, nii_data.get_fdata()) + boxes_nii = nib.Nifti1Image(pred_boxes_image, nii_data.affine) # keep all metadata from original image + boxes_nii.header.set_data_shape(nii_data.shape) + + nib.save(boxes_nii, str(output_folder/ (volume +"_pred.nii.gz"))) + + else: + # If no prediction boxes exist for this volume + pred_bboxes = None + + + ## 3- Get metrics for given volume and add to all_metrics_df + tp, fn, fp = confusion_matrix(label_bboxes, pred_bboxes, iou_threshold) + all_metrics_df= pd.concat([all_metrics_df, pd.DataFrame([[volume, tp, fp, fn]], columns=df_columns)], ignore_index=True) + + return all_metrics_df + + +def _main(): + parser = ArgumentParser( + prog = 'Validation', + description = 'Validate a yolo model', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-g', '--gt-path', + required= True, + type = str, + help = 'Path to YOLO dataset folder of ground truth txt files') + parser.add_argument('-p', '--preds-path', + required= True, + type = str, + help = 'Path to prediction folder of txt files') + parser.add_argument('-o', '--output-folder', + required= True, + type = Path, + help = 'Path to directory where results will be saved') + parser.add_argument('-c', '--canproco', + required= True, + type = Path, + help = 'Path to canproco database') + parser.add_argument('-i', '--iou', + default= 0.2, + type = float, + help = 'IoU threshold for a TP') + + args = parser.parse_args() + + # Create output folder if it doesn't exist + os.makedirs(args.output_folder, exist_ok=True) + + ## 1-Get ground truth + print("Loading ground truth...") + # Get list of txt file paths from dataset path + txt_names = os.listdir(args.gt_path) + 
txt_paths = [os.path.join(args.gt_path, file) for file in txt_names if file.endswith(".txt")] # only keep txts + + # Get dictionary with volume names as keys + # And ground truth bounding box tensors as values + labels_dict = get_volume_boxes(txt_paths, Path(get_png_from_txt(args.gt_path)), IOSA) + + + ## 2-Get predictions + print("Loading predictions...") + # Get list of txt file paths from dataset path + txt_names = os.listdir(args.preds_path) + txt_paths = [os.path.join(args.preds_path, file) for file in txt_names if file.endswith(".txt")] #only keep txts + + # Get dictionary with volume names as keys + # And prediction bounding box tensors as values + preds_dict = get_volume_boxes(txt_paths, Path(get_png_from_txt(args.gt_path)), IOSA) + + + ## Save images with labels and predictions + print("Saving images...") + # Since not all images have a txt file (some don't contain lesions), get list of + # all images from images folder instead of labels folder + img_names = os.listdir(get_png_from_txt(args.gt_path)) + img_paths = [os.path.join(get_png_from_txt(args.gt_path), file) for file in img_names if file.endswith(".png")] #only keep pngs + + # Get a list of all volumes from image paths + volumes=[] + for img_path in img_paths: + parts = Path(img_path).name.split('_') + volume = '_'.join(parts[:-1]) + if not volume in volumes: + volumes.append(volume) + + # Compute metrics and save nifti images of label and pred boxes + all_metrics_df = compute_metrics(volumes, labels_dict, preds_dict, args.canproco, args.output_folder, args.iou) + + all_tp = all_metrics_df['TP'].sum() + all_fp = all_metrics_df['FP'].sum() + all_fn = all_metrics_df['FN'].sum() + + precision = all_tp/(all_tp+all_fp) + recall = all_tp/(all_tp+all_fn) + + # Add a row to dataframe with recall and precision + new_row = pd.DataFrame({'Volume': ['Total'], + 'TP': all_tp, + 'FP': all_fp, + 'FN': all_fn, + 'Recall': recall, + 'Precision': precision}) + + all_metrics_df= pd.concat([all_metrics_df, new_row], ignore_index=True) + + # Save dataframe to csv file + all_metrics_df.to_csv(str(args.output_folder/"metrics_report.csv"), index=False) + + # Print final metrics + print('\nRecall: ', recall) + print('Precision: ', precision) + + +if __name__ == "__main__": + _main() \ No newline at end of file diff --git a/2d_lesion_detection/yolo_hyperparameter_tune.py b/2d_lesion_detection/yolo_hyperparameter_tune.py new file mode 100644 index 0000000..5ce2129 --- /dev/null +++ b/2d_lesion_detection/yolo_hyperparameter_tune.py @@ -0,0 +1,86 @@ +""" +Make sure to have installed ray tune: + pip install "ray[tune]<=2.9.3" + +Performs a hyperparameter search on a yolov8n model, using ray tune. 
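+
+For illustration only, a params json could look like the sketch below (the keys and
+ranges here are hypothetical examples, not the shipped defaults): a two-element list is
+interpreted as a [min, max] range to tune over with tune.uniform, while a scalar is
+passed through unchanged as a fixed training argument.
+
+    {
+        "lr0": [1e-5, 1e-1],
+        "weight_decay": [0.0, 0.001],
+        "batch": 32
+    }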
+ +To track with wandb, make sure wandb is installed (pip install wandb), +then log into wandb account by following these steps: +https://docs.ultralytics.com/integrations/weights-biases/#configuring-weights-biases +""" + +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import json +from pathlib import Path +from ultralytics import YOLO +from ray import tune + + +def _main(): + parser = ArgumentParser( + prog = 'yolo_hyperparameter_tune', + description = 'Perform a hyperparameter search on yolov8n model', + formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument('-d', '--data', + required = True, + type = Path, + help = 'Path to data yaml file') + parser.add_argument('-n', '--name', + default=None, + type = str, + help = 'Run name') + parser.add_argument('-p', '--params', + default="default_tune_params.json", + type = Path, + help = 'Path to parameter file. See default_tune_params.json for format.') + parser.add_argument('-g', '--device', + default= 0, + type = str, + help = 'Device to use. 0 refers to gpu 0') + + args = parser.parse_args() + + # Get device as int if input is int + try: + device = int(args.device) # Try converting to an integer + except ValueError: + device = args.device # keep as string + + # It seems like tune currently only works with the smallest model (n): + # https://github.com/ultralytics/ultralytics/issues/2265 + model = YOLO('yolov8n.pt') + + # Get parameters from params file + with open(args.params, 'r') as file: + config = json.load(file) + + param_space = {} + fixed_params = {} + for key, value in config.items(): + if isinstance(value, list): # If the value is a list, it is a tuning parameter + param_space[key] = tune.uniform(value[0], value[1]) + else: # Otherwise, it's a fixed value + fixed_params[key] = value + + # setting epochs to 40 leads to an error + # the error is avoided with 100 epochs, as suggested here: https://github.com/ultralytics/ultralytics/issues/5874 + result_grid = model.tune(data=args.data, + use_ray=True, + space=param_space, + epochs=100, + device=device, + name=args.name, + **fixed_params) + + if result_grid.errors: + print("One or more trials failed!") + else: + print("No errors!") + + for i, result in enumerate(result_grid): + print(f"Trial #{i}: Configuration: {result.config}, Last Reported Metrics: {result.metrics}") + + +if __name__ == "__main__": + _main() + \ No newline at end of file diff --git a/2d_lesion_detection/yolo_inference.py b/2d_lesion_detection/yolo_inference.py new file mode 100644 index 0000000..e9c1b27 --- /dev/null +++ b/2d_lesion_detection/yolo_inference.py @@ -0,0 +1,124 @@ +""" +Functions to predict slice-wise MS lesion positions using a trained YOLOv8 model. 
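+
+Example invocation (the paths below are illustrative placeholders, not actual files):
+
+    python yolo_inference.py -m runs/detect/train/weights/best.pt \
+        -d yolo_dataset/images/test -o predictions/test --keep-conf-values
+
+Each slice with at least one detection gets its own YOLO-format txt file in the
+output folder; slices with no detections are skipped.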
+
+Optionally performs post-processing by taking all slice predictions for a volume and
+merging boxes that overlap
+
+Dataset should be formatted in YOLO format as defined in pre-processing.py
+"""
+import math
+import os
+from pathlib import Path
+from typing import List, Dict
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import torch
+import numpy as np
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
+
+
+def _get_slice_results_dict(results:List[Results], keep_conf_values:bool)-> Dict[str, torch.Tensor]:
+    """
+    Get a dictionary of YOLO results with slice name as key
+
+    Args:
+        results (List[Results]): list of results from YOLO predict mode
+        keep_conf_values (bool): If True, add confidence values to boxes tensor
+
+    Returns:
+        result_dict (Dict[str, torch.Tensor]): dictionary containing predictions for every slice
+            !! boxes are in x_center, y_center, width, height normalized format !!
+    """
+    # Sort results into a dictionary with slice names as keys
+    result_dict = {}
+    for result in results:
+        slice_name = Path(result.path).name.replace(".png", "")
+        boxes = result.boxes.xywhn
+
+        if keep_conf_values:
+            conf = result.boxes.conf
+
+            boxes = torch.cat((boxes, conf.unsqueeze(1)), dim=1)
+
+        result_dict[slice_name] = boxes
+
+    return result_dict
+
+
+def _main():
+    parser = ArgumentParser(
+        prog = 'yolo_inference',
+        description = 'Detect MS lesions with a YOLOv8 model',
+        formatter_class = ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-m', '--model',
+                        required = True,
+                        type = Path,
+                        help = 'Path to trained YOLO model .pt file')
+    parser.add_argument('-d', '--dataset-path',
+                        required= True,
+                        type = Path,
+                        help = 'Path to dataset folder of png images. Check pre-process.py for format.')
+    parser.add_argument('-o', '--output-folder',
+                        required= True,
+                        type = Path,
+                        help = 'Path to directory where results will be saved')
+    parser.add_argument('-b', '--batch-size',
+                        default= 64,
+                        type = int,
+                        help = 'Batch size to use for inference.')
+    parser.add_argument('-c', '--conf-threshold',
+                        default= 0.2,
+                        type = float,
+                        help = 'Confidence threshold to keep model predictions.')
+    parser.add_argument('-k', '--keep-conf-values',
+                        action= "store_true",
+                        default = False,
+                        help = 'Save conf values in txt file')
+
+    args = parser.parse_args()
+
+    # Create output folder if it doesn't exist
+    os.makedirs(args.output_folder, exist_ok=True)
+
+    # Get list of image paths from dataset path
+    img_names = os.listdir(args.dataset_path)
+    img_paths = [os.path.join(args.dataset_path, file) for file in img_names if file.endswith(".png")] # only keep pngs
+
+    # Load model
+    model = YOLO(args.model)
+
+    # Perform inference in batches
+    # From https://github.com/ultralytics/ultralytics/issues/4835
+    results=[]
+    for i in range(0, len(img_paths), args.batch_size):
+        print(f"\nPredicting batch {int(i/args.batch_size)+1}/{math.ceil(len(img_paths)/args.batch_size)}")
+        preds = model.predict(img_paths[i:i+args.batch_size], conf=args.conf_threshold)
+        for pred in preds:
+            results.append(pred)
+
+    # Put results in a dictionary
+    result_dict = _get_slice_results_dict(results, args.keep_conf_values)
+
+
+    # Save results
+    print(f"\nSaving results to {str(args.output_folder)}")
+    for name, boxes in result_dict.items():
+        # name is the slice name
+        # (e.g. sub-cal080_ses-M0_STIR_2)
+
+        if boxes.numel() == 0:
+            # If no boxes are predicted, skip slice (no txt file is saved)
+            continue
+
+        # Save to txt
+        with open(args.output_folder/(name+".txt"), "w") as file:
+            # Iterate over the tensors in the list
+            for box in boxes:
+                line = ' '.join(['0'] + [str(val) for val in box.cpu().data.numpy()]) # add a zero to indicate the class (assuming there is only one class)
+                file.write(line + '\n')
+
+
+
+if __name__ == "__main__":
+    _main()
\ No newline at end of file
diff --git a/2d_lesion_detection/yolo_training.py b/2d_lesion_detection/yolo_training.py
new file mode 100644
index 0000000..72b934c
--- /dev/null
+++ b/2d_lesion_detection/yolo_training.py
@@ -0,0 +1,87 @@
+"""
+Script to train a YOLOv8 model
+
+Training progress can be tracked using clearML (other platforms are also integrated but haven't been tested):
+https://docs.ultralytics.com/integrations/clearml/#configuring-clearml
+
+"""
+
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+import json
+from pathlib import Path
+
+import torch
+from ultralytics import YOLO
+
+CANPROCO_VERSION = "bcd627ed4" # last commit from canproco repo
+
+def _main():
+    parser = ArgumentParser(
+        prog = 'yolo_training',
+        description = 'Train a YOLOv8 model',
+        formatter_class = ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-d', '--data',
+                        required = True,
+                        type = Path,
+                        help = 'Path to data yaml file')
+    parser.add_argument('-p', '--params',
+                        default = "default_train_params.json",
+                        type = Path,
+                        help = 'Path to params json file')
+    parser.add_argument('-n', '--name',
+                        default= None,
+                        type = str,
+                        help = 'Model name')
+    parser.add_argument('-g', '--device',
+                        default= 0,
+                        type = str,
+                        help = 'Device to use. 0 refers to gpu 0')
+
+    args = parser.parse_args()
+
+    # Get device as int if input is int
+    try:
+        device = int(args.device) # Try converting to an integer
+    except ValueError:
+        device = args.device # keep as string
+
+    # Get parameters
+    with open(args.params, 'r') as file:
+        config = json.load(file)
+
+    # Load a pretrained model
+    model = YOLO('yolov8n.pt')
+
+    # Train the model
+    model.train(data=args.data,
+                name=args.name,
+                device = device,
+                **config)
+
+
+    ## Add canproco version to model metadata
+    # Define the metadata
+    metadata = {'dataset_version': CANPROCO_VERSION,
+                'description': 'Model for multiple sclerosis lesion detection on MRI images of spinal cord.'}
+    save_dir = model.trainer.save_dir
+
+    # best.pt
+    print("Adding metadata to best.pt")
+    model_path = save_dir/"weights"/"best.pt"
+    best_model = torch.load(model_path)
+    best_model["metadata"] = metadata
+    torch.save(best_model, model_path)
+
+    # last.pt
+    print("Adding metadata to last.pt")
+    model_path = save_dir/"weights"/"last.pt"
+    last_model = torch.load(model_path)
+    last_model["metadata"] = metadata
+    torch.save(last_model, model_path)
+
+    # To access the metadata later:
+    # metadata_read = torch.load(model_path).get('metadata', None)
+
+
+if __name__ == "__main__":
+    _main()
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29