"""
Notebook to run inference and generate figures

"""

# %%
from collections import OrderedDict, defaultdict

import pandas as pd
import cv2
import torch
import os
import yaml
import numpy as np
import json
from tqdm import tqdm
import nibabel as nib
from shutil import copy

import matplotlib.pyplot as plt
from matplotlib import cm
import albumentations
from torchvision.utils import make_grid

import SimpleITK as sitk

# from mpl_toolkits.mplot3d import Axes3D

# enable lib loading even if not installed as a pip package or in PYTHONPATH
# also convenient for relative paths in example config files
from pathlib import Path

os.chdir(Path(__file__).resolve().parent.parent)

from adpkd_segmentation.config.config_utils import get_object_instance  # noqa
from adpkd_segmentation.datasets import dataloader as _dataloader  # noqa
from adpkd_segmentation.datasets import datasets as _datasets  # noqa
from adpkd_segmentation.data.link_data import makelinks  # noqa
from adpkd_segmentation.data.data_utils import display_sample  # noqa
from adpkd_segmentation.utils.train_utils import load_model_data  # noqa
from adpkd_segmentation.utils.stats_utils import (  # noqa
    bland_altman_plot,
    scatter_plot,
    linreg_plot,
)

from adpkd_segmentation.utils.losses import (
    SigmoidBinarize,
    Dice,
    binarize_thresholds,
)  # noqa
from torch.nn import Sigmoid  # noqa

# %%


class _NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to builtins."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def load_config(config_path, run_makelinks=False):
    """Builds the model, dataloader and prediction post-processing object
    described by a YAML inference config.

    Args:
        config_path (str): config file path
        run_makelinks (bool, optional): Calls ``makelinks()`` before the
            config is read (creates symbolic links during the first run).
            Defaults to False.

    Returns:
        tuple: (dataloader, model, device, pred_process, save_dir (str),
            model_name (str)) -- the positional order accepted by
            ``inference_to_disk``.
    """
    if run_makelinks:
        makelinks()

    # NOTE(review): yaml.FullLoader is not meant for untrusted input;
    # only point this at config files you control.
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_config = config["_MODEL_CONFIG"]
    loader_to_eval = config["_LOADER_TO_EVAL"]
    dataloader_config = config[loader_to_eval]
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]

    model = get_object_instance(model_config)()
    if saved_checkpoint is not None:
        load_model_data(saved_checkpoint, model, new_format=checkpoint_format)

    dataloader = get_object_instance(dataloader_config)()

    # TODO: support other metrics as needed
    # binarize_func = SigmoidBinarize(thresholds=[0.5])

    # The prediction post-processing is taken from the dice metric entry
    # of the config.
    pred_process_config = config["_LOSSES_METRICS_CONFIG"]["criterions_dict"][
        "dice_metric"
    ]["pred_process"]
    pred_process = get_object_instance(pred_process_config)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Third path component from the end of the absolute config path, i.e.
    # the directory that contains the config file's parent directory.
    model_name = Path(config_path).absolute().parts[-3]

    save_dir = "./saved_inference"

    return (dataloader, model, device, pred_process, save_dir, model_name)


def plot_model_results(csv_path, name):
    """Plots Bland-Altman, dice-vs-TKV scatter and linear-fit figures.

    Args:
        csv_path (path): CSV with the columns ``TKV_Pred``, ``TKV_GT`` and
            ``patient_dice``.
        name (str): prefix used in the plot titles.
    """
    df = pd.read_csv(csv_path)
    pred = df["TKV_Pred"].to_numpy()
    gt = df["TKV_GT"].to_numpy()
    patient_dice = df["patient_dice"].to_numpy()

    bland_altman_plot(
        pred, gt, percent=True, title=f"{name} BA Plot: TKV % error"
    )
    scatter_plot(patient_dice, gt, title=f"{name} Dice by TKV")
    linreg_plot(pred, gt, title=f"{name} Linear Fit")


def _find_resize_dim(dataset):
    """Returns (height, width) of the first Resize transform found in
    ``dataset.augmentation.transforms``."""
    transform_resize = next(
        (
            t
            for t in dataset.augmentation.transforms
            if isinstance(t, albumentations.Resize)
        ),
        None,  # default keeps the assertion below reachable
    )
    assert (
        transform_resize is not None
    ), "transform_resize must be defined"
    return (transform_resize.height, transform_resize.width)


def inference_to_disk(
    dataloader,
    model,
    device,
    binarize_func,
    save_dir="./saved_inference",
    model_name="model",
):
    """Runs the model over the dataloader and saves per-slice results.

    For every sample the following files are written to
    ``<cwd>/<save_dir>/<model_name>/<patient>/<MR>/``, all sharing the
    stem of the source dcm file: ``_img.npy`` (model input),
    ``_logit.npy`` (raw model output), ``_pred.npy`` (output of
    ``binarize_func``), ``.dcm`` (copy of the source file) and
    ``_attrib.json`` (the sample's attribute dict plus
    ``transform_resize_dim``).

    Args:
        dataloader: loader yielding ``(x_batch, idxs_batch)``; its dataset
            must expose ``output_idx``, ``get_verbose`` and
            ``augmentation``.
        model: callable applied to ``x_batch``.
        device: torch device the batch is moved to.
        binarize_func: callable applied to the model output.
        save_dir (str, optional): output root, relative to the current
            working directory. Defaults to "./saved_inference".
        model_name (str, optional): sub directory name. Defaults to
            "model".
    """
    dataset = dataloader.dataset
    assert getattr(
        dataset, "output_idx", False
    ), "output indexes are required for the dataset"

    # The resize transform belongs to the dataset, not to a sample, so it
    # is looked up once instead of once per saved slice.
    resize_dim = _find_resize_dim(dataset)
    out_root = Path.cwd() / Path(save_dir) / model_name

    for x_batch, idxs_batch in dataloader:
        x_batch = x_batch.to(device)

        with torch.no_grad():
            # get_verbose returns (sample, dcm_path, attributes dict)
            verbose = [dataset.get_verbose(idx) for idx in idxs_batch]

            y_batch_hat = model(x_batch)
            # TODO: support only sigmoid saves
            y_batch_hat_binary = binarize_func(y_batch_hat)

            for (_, dcm_file, file_attrib), img, logit, pred in zip(
                verbose, x_batch, y_batch_hat, y_batch_hat_binary
            ):
                dcm_path = Path(dcm_file)
                # common path prefix of every file saved for this slice
                out_stem = (
                    out_root
                    / file_attrib["patient"]
                    / file_attrib["MR"]
                    / dcm_path.stem
                )
                out_stem.parent.mkdir(parents=True, exist_ok=True)

                np.save(str(out_stem) + "_img", img.cpu().numpy())
                np.save(str(out_stem) + "_logit", logit.cpu().numpy())
                np.save(str(out_stem) + "_pred", pred.cpu().numpy())
                # Append the suffix instead of Path.with_suffix: stems
                # that contain dots would otherwise lose their last
                # component and different slices could collide.
                copy(dcm_path, str(out_stem) + ".dcm")

                file_attrib["transform_resize_dim"] = resize_dim
                with open(str(out_stem) + "_attrib.json", "w") as f:
                    json.dump(file_attrib, f, cls=_NpEncoder)


# %%
def resized_stack(numpy_list, dsize=None):
    """resizing lists of array with dimension:
    slices x 1 x H x W, where H = W.

    Sets output size to first array at idx 0 or dsize

    Args:
        numpy_list (list): list of numpy arr
        dsize (tuple, optional): output size handed to ``cv2.resize``.
            Defaults to None, which uses the H, W of the first array.

    Returns:
        numpy: stacked numpy lists with same size
    """
    first = numpy_list[0]
    assert first.shape[1] == 1, "dimension check"
    assert first.shape[2] == first.shape[3], "square check"

    def reshape(arr):
        """reshapes [slices x 1 x H x W] to [H x W x slices]"""
        # NOTE(review): np.squeeze drops every length-1 axis, so a
        # single-slice volume presumably loses its slice axis as well.
        return np.squeeze(np.moveaxis(arr, 0, -1))

    reshaped = [reshape(arr) for arr in numpy_list]

    if dsize is None:
        dsize = reshaped[0].shape[0:2]  # get H, W from first arr

    return np.stack(
        [
            cv2.resize(src, dsize, interpolation=cv2.INTER_CUBIC)
            for src in reshaped
        ]
    )


def display_volumes(
    study_dir,
    style="prob",
    plot_error=False,
    skip_display=True,
    save_dir=None,
    output_style="png",
):
    """Displays inference over original image

    Args:
        study_dir (path): Directory of inferences.
        style (str, optional): Type of data displayed and returned, one of
            "img", "logit", "pred" or "prob".
            Defaults to "prob" for probability.
        plot_error (bool, optional): Currently unused. Defaults to False.
        skip_display (bool, optional): Skips plotting when True.
            Defaults to True.
        save_dir (path, optional): Directory to save figs. When given, the
            overlay grid is saved as label_grid.svg and label_grid.png
            instead of being shown (not applied for style "img").
            Defaults to None.
        output_style (str, optional): Currently unused. Defaults to "png".

    Returns:
        numpy: the stacked volume selected by ``style``
    """

    print(f"loading from {study_dir}")
    study_dir = Path(study_dir)

    def load_stack(pattern):
        """stacks all arrays matching pattern, sorted by file name"""
        return np.stack([np.load(p) for p in sorted(study_dir.glob(pattern))])

    img_vol = load_stack("*_img.npy")
    logit_vol = load_stack("*_logit.npy")
    pred_vol = load_stack("*_pred.npy")

    vols = {
        "img": img_vol,
        "logit": logit_vol,
        "pred": pred_vol,
        "prob": torch.sigmoid(torch.from_numpy(logit_vol)).numpy(),
    }

    def show(img, label=None, img_alpha=1, lb_alpha=0.3):
        """draws a C x H x W grid, optionally with a label grid on top"""
        fig, ax = plt.subplots(figsize=(20, 10))
        ax.imshow(
            np.transpose(img.numpy(), (1, 2, 0)),
            interpolation="none",
            alpha=img_alpha,
        )
        if label is not None:
            ax.imshow(
                np.transpose(label.numpy(), (1, 2, 0)),
                alpha=lb_alpha,
                interpolation="none",
            )

    y = vols[style]

    print(f"style is: {style}")
    print(f"vol stats: min:{y.min()} max:{y.max()} mean:{y.mean()}")

    if skip_display:
        return y

    x = torch.from_numpy(vols["img"])
    if style == "img":
        show(make_grid(x), lb_alpha=0.5)
        plt.show()
        return y

    # The colour-mapped overlay is only needed when something is drawn,
    # so it is built here rather than on every call.
    bkgrd_thresh = 0.01
    cmap_vol = np.ma.masked_where(y <= bkgrd_thresh, y)
    cmap_vol = np.apply_along_axis(cm.inferno, 0, cmap_vol)
    cmap_vol = torch.from_numpy(np.squeeze(cmap_vol))

    show(make_grid(x), make_grid(cmap_vol), lb_alpha=0.5)
    if save_dir is None:
        plt.show()
    else:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(Path(save_dir) / "label_grid.svg")
        plt.savefig(Path(save_dir) / "label_grid.png")
        # release the figure; batch runs would otherwise keep every
        # 20x10 figure open
        plt.close()
    return y


def exam_preds_to_stat(
    pred_vol, ground_vol, pred_process, attrib_dict, pred_std=None
):
    """computes stats for a single exam prediction

    Args:
        pred_vol (numpy): prediction volume
        ground_vol (numpy): ground truth volume
        pred_process (function): converts prediction to binary; it is
            applied to both volumes before voxels are counted
        attrib_dict (dict): dictionary of attributes (usually from index
            0). Must hold "dim", "transform_resize_dim", "vox_vol",
            "patient" and "MR". Updated in place.
        pred_std (optional): stored under "Pred_stdev". Defaults to None.

    Returns:
        dict: ``attrib_dict`` itself, extended with "TKV_GT", "TKV_Pred",
            "patient_dice", "study", "scale_factor" and "Pred_stdev"
    """
    pred_tensor = torch.from_numpy(pred_vol)
    ground_tensor = torch.from_numpy(ground_vol)

    dice = Dice(
        pred_process=pred_process, use_as_loss=False, power=1, dim=(0, 1, 2, 3)
    )
    dice_val = dice(pred_tensor, ground_tensor).item()

    # ratio of the squared first entry of "dim" to the squared first
    # entry of the resize dimensions
    scale_factor = (attrib_dict["dim"][0] ** 2) / (
        attrib_dict["transform_resize_dim"][0] ** 2
    )
    # print(f"scale factor {scale_factor}")
    vox_vol = attrib_dict["vox_vol"]

    pred_pixel_count = torch.sum(pred_process(pred_tensor)).item()
    volume_pred = scale_factor * vox_vol * pred_pixel_count

    ground_pixel_count = torch.sum(pred_process(ground_tensor)).item()
    volume_ground = scale_factor * vox_vol * ground_pixel_count

    attrib_dict.update(
        {
            "TKV_GT": volume_ground,
            "TKV_Pred": volume_pred,
            "patient_dice": dice_val,
            "study": attrib_dict["patient"] + attrib_dict["MR"],
            "scale_factor": scale_factor,
            "Pred_stdev": pred_std,
        }
    )

    return attrib_dict


def compute_inference_stats(
    save_dir, output=False, display=False, patient_ID=None
):
    """Computes per-model and model-averaged stats from saved inferences.

    Every sub directory of ``<cwd>/<save_dir>`` is treated as one model's
    inference output with the layout ``<patient>/<MR>/<files>``. Studies
    lacking logit, pred, ground or attrib files are skipped with a
    message.

    Args:
        save_dir (path): directory holding the model inference folders.
        output (bool, optional): writes ``stats-<model>.csv`` per model
            and ``stats-combined_models.csv`` to the current working
            directory. Defaults to False.
        display (bool, optional): calls ``display_volumes`` for every
            study. Defaults to False.
        patient_ID (str, optional): restricts the studies to one patient
            folder. Defaults to None, meaning all patients.
    """
    Metric_data = OrderedDict()
    Combined_metric_data = OrderedDict()
    root = Path.cwd() / Path(save_dir)

    model_inferences = list(root.glob("*"))
    newline = "\n"
    formated_list = "".join([f"{newline} {m}" for m in model_inferences])

    print(f"calculating model inferences for {formated_list}")

    all_logit_vol = defaultdict(list)
    all_pred_vol = defaultdict(list)
    all_ground_vol = defaultdict(list)
    all_summaries = defaultdict(list)

    pred_process = SigmoidBinarize(thresholds=[0.5])

    # Kept in locals so the patient_ID argument is not overwritten.
    patient_pattern = "*" if patient_ID is None else patient_ID
    study_pattern = f"{patient_pattern}/*"

    for model_dir in tqdm(model_inferences):
        for study_dir in model_dir.glob(study_pattern):
            logits = sorted(study_dir.glob("*_logit.npy"))
            logits_np = [np.load(logit) for logit in logits]
            preds = sorted(study_dir.glob("*_pred.npy"))
            preds_np = [np.load(p) for p in preds]
            grounds = sorted(study_dir.glob("*_ground.npy"))
            grounds_np = [np.load(g) for g in grounds]
            attribs = sorted(study_dir.glob("*_attrib.json"))
            attribs_dicts = []
            for a in attribs:
                with open(a) as json_file:
                    attribs_dicts.append(json.load(json_file))

            # np.stack([]) and attribs_dicts[0] would raise on an
            # incomplete study; report and move on instead.
            if not (logits_np and preds_np and grounds_np and attribs_dicts):
                print(
                    f"skipping {study_dir}: "
                    "missing logit, pred, ground or attrib files"
                )
                continue

            # volumes for a study within one model inference
            logit_vol = np.stack(logits_np)
            pred_vol = np.stack(preds_np)
            ground_vol = np.stack(grounds_np)

            if display:
                # display_volumes loads the arrays from the study folder
                display_volumes(
                    study_dir=study_dir, style="pred", skip_display=False
                )

            summary = exam_preds_to_stat(
                pred_vol, ground_vol, pred_process, attribs_dicts[0]
            )
            study_key = summary["study"]

            Metric_data[study_key] = summary

            # accumulate predictions across all models for each study
            all_logit_vol[study_key].append(logit_vol)
            all_pred_vol[study_key].append(pred_vol)
            all_ground_vol[study_key].append(ground_vol)
            all_summaries[study_key].append(summary)

        # NOTE(review): Metric_data is not reset between models, so this
        # frame also holds studies recorded for earlier models.
        df = pd.DataFrame(Metric_data).transpose()

        if output:
            df.to_csv(f"stats-{model_dir.name}.csv")

    for key, value in all_logit_vol.items():
        # uses index 0 to get ground truth and standard voxel attribs

        # resizes by index 0; the logits are averaged across models
        # without a sigmoid
        prob_vol = resized_stack(value)
        prob_vol = np.mean(prob_vol, axis=0)
        # scalar std over the averaged volume
        prob_std = np.std(prob_vol)

        prob_vol = np.moveaxis(prob_vol, -1, 0)  # b x (X x Y)
        prob_vol = np.expand_dims(prob_vol, axis=1)  # b x c x (X x Y)
        pred_vol = binarize_thresholds(torch.from_numpy(prob_vol)).numpy()
        ground_vol = all_ground_vol[key][0]

        summary = exam_preds_to_stat(
            pred_vol,
            ground_vol,
            pred_process,
            all_summaries[key][0],
            pred_std=prob_std,
        )

        Combined_metric_data[summary["study"]] = summary

    df = pd.DataFrame(Combined_metric_data).transpose()

    if output:
        print("saving combined csv")
        df.to_csv("stats-combined_models.csv")


# %%
# Load models
paths = [
    "checkpoints/inference.yml",
]

# %%
# Run inferences
for p in tqdm(paths):
    model_args = load_config(config_path=p)
    inference_to_disk(*model_args)

# %%
# Show sample inference
root_saved_inference = Path("saved_inference")
study = "adpkd-segmentation/1001707280/Axial T2 SS-FSE"

y = display_volumes(
    study_dir=root_saved_inference / study,
    style="pred",
    plot_error=True,
    skip_display=False,
)
# %%
# Creating figures for all inferences

# Get all model inferences
inference_files = Path("saved_inference").glob("**/*")

# Folders are of form 'saved_inference/adpkd-segmentation/{PATIENT-ID}/{SERIES}'
# Sorted so the processing order (and RESUME_FROM below) is reproducible.
folders = sorted(
    {f.parent for f in inference_files if len(f.parent.parts) == 4}
)

IDX_series = -1
IDX_ID = -2

saved_folders = [
    Path("saved_figs") / f"{d.parts[IDX_ID]}_{d.parts[IDX_series]}"
    for d in folders
]
# %%
# Generate figures for all inferences

# Index of the first study to process; raise it to resume an interrupted
# run without redoing earlier studies.
RESUME_FROM = 0

for study_dir, save_dir in tqdm(
    list(zip(folders, saved_folders))[RESUME_FROM:]
):
    try:
        # Save inference figure to save_dir
        y = display_volumes(
            study_dir=study_dir,
            style="pred",
            plot_error=True,
            skip_display=False,
            save_dir=save_dir,
        )
    except Exception as e:
        # best effort: one broken study must not stop the batch
        print(e)
# %%

inference_files = Path("saved_inference").glob("**/*img.npy")
folders = sorted({f.parent for f in inference_files})
print(folders)

# %%

# collected across all folders
errors = []

for folder in folders:
    print(f"Folder: {folder}")
    imgs = sorted(folder.glob("*img.npy"))
    # first channel of every saved slice
    np_imgs = [np.load(i)[0] for i in imgs]
    np_vol = np.stack(np_imgs).T
    print(np_vol.shape)
    print(f"{folder.name}_img_vol.nii")

    dicom_paths = sorted(folder.glob("*.dcm"))
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames([str(path) for path in dicom_paths])
    try:
        image_3d = reader.Execute()
        sitk.WriteImage(
            image_3d,
            str(folder / f"{folder.name}_original.nii"),  # noqa
        )

    except Exception as e:
        print(e)
        errors.append(f"error:{str(e)}\n folder:{folder}")

    nifi_vol = nib.Nifti1Image(np_vol, affine=np.eye(4))
    nib.save(nifi_vol, str(folder / f"{folder.name}_img_vol.nii"))