"""
Notebook to run inference and generate figures
"""
# %%
from collections import OrderedDict, defaultdict
import pandas as pd
import cv2
import torch
import os
import yaml
import numpy as np
import json
from tqdm import tqdm
import nibabel as nib
from shutil import copy
import matplotlib.pyplot as plt
from matplotlib import cm
import albumentations
from torchvision.utils import make_grid
import SimpleITK as sitk
# from mpl_toolkits.mplot3d import Axes3D
# enable lib loading even if not installed as a pip package or in PYTHONPATH
# also convenient for relative paths in example config files
from pathlib import Path
os.chdir(Path(__file__).resolve().parent.parent)
from adpkd_segmentation.config.config_utils import get_object_instance # noqa
from adpkd_segmentation.datasets import dataloader as _dataloader # noqa
from adpkd_segmentation.datasets import datasets as _datasets # noqa
from adpkd_segmentation.data.link_data import makelinks # noqa
from adpkd_segmentation.data.data_utils import display_sample # noqa
from adpkd_segmentation.utils.train_utils import load_model_data # noqa
from adpkd_segmentation.utils.stats_utils import ( # noqa
bland_altman_plot,
scatter_plot,
linreg_plot,
)
from adpkd_segmentation.utils.losses import (
SigmoidBinarize,
Dice,
binarize_thresholds,
) # noqa
# %%
def load_config(config_path, run_makelinks=False):
"""Reads config file and calculates additional dcm attributes such as
slice volume. Returns a dictionary used for patient wide calculations
such as TKV.
Args:
config_path (str): config file path
run_makelinks (bool, optional): Creates symbolic links during the first run. Defaults to False.
Returns:
dataloader, model, device, binarize_func, save_dir (str), model_name (str), split (str)
"""
if run_makelinks:
makelinks()
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
model_config = config["_MODEL_CONFIG"]
loader_to_eval = config["_LOADER_TO_EVAL"]
dataloader_config = config[loader_to_eval]
saved_checkpoint = config["_MODEL_CHECKPOINT"]
checkpoint_format = config["_NEW_CKP_FORMAT"]
model = get_object_instance(model_config)()
if saved_checkpoint is not None:
load_model_data(saved_checkpoint, model, new_format=checkpoint_format)
dataloader = get_object_instance(dataloader_config)()
# TODO: support other metrics as needed
# binarize_func = SigmoidBinarize(thresholds=[0.5])
pred_process_config = config["_LOSSES_METRICS_CONFIG"]["criterions_dict"][
"dice_metric"
]["pred_process"]
pred_process = get_object_instance(pred_process_config)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
model_name = Path(config_path).absolute().parts[-3]
save_dir = "./saved_inference"
return (dataloader, model, device, pred_process, save_dir, model_name)
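

# A minimal sketch of the YAML layout `load_config` expects. The top-level
# keys come from the lookups above; the loader key name and the nested object
# specs are assumptions and depend on the project's own config files:
#
# _MODEL_CONFIG: {...}                   # model spec for get_object_instance
# _LOADER_TO_EVAL: _VAL_DATALOADER_CONFIG
# _VAL_DATALOADER_CONFIG: {...}          # dataloader spec
# _MODEL_CHECKPOINT: path/to/checkpoint.pth
# _NEW_CKP_FORMAT: True
# _LOSSES_METRICS_CONFIG:
#   criterions_dict:
#     dice_metric:
#       pred_process: {...}              # e.g. a SigmoidBinarize spec
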
def plot_model_results(csv_path, name):
df = pd.read_csv(csv_path)
pred = df["TKV_Pred"].to_numpy()
gt = df["TKV_GT"].to_numpy()
bland_altman_plot(
pred, gt, percent=True, title=f"{name} BA Plot: TKV % error"
)
patient_dice = df["patient_dice"].to_numpy()
scatter_plot(patient_dice, gt, title=f"{name} Dice by TKV")
linreg_plot(pred, gt, title=f"{name} Linear Fit")
def inference_to_disk(
dataloader,
model,
device,
binarize_func,
save_dir="./saved_inference",
model_name="model",
):
dataset = dataloader.dataset
output_idx_check = (
hasattr(dataloader.dataset, "output_idx")
and dataloader.dataset.output_idx
)
assert (
output_idx_check is True
), "output indexes are required for the dataset"
for batch_idx, output in enumerate(dataloader):
x_batch, idxs_batch = output
x_batch = x_batch.to(device)
with torch.no_grad():
# get_verbose returns (sample, dcm_path, attributes dict)
dcm_file_paths = [
Path(dataset.get_verbose(idx)[1]) for idx in idxs_batch
]
dcm_file_names = [
Path(dataset.get_verbose(idx)[1]).stem for idx in idxs_batch
]
file_attribs = [dataset.get_verbose(idx)[2] for idx in idxs_batch]
y_batch_hat = model(x_batch)
            # TODO: currently assumes sigmoid-based binarization for saves
y_batch_hat_binary = binarize_func(y_batch_hat)
for dcm_path, dcm_name, file_attrib, img, logit, pred in zip(
dcm_file_paths,
dcm_file_names,
file_attribs,
x_batch,
y_batch_hat,
y_batch_hat_binary,
):
out_dir = (
Path.cwd()
/ Path(save_dir)
/ model_name
/ file_attrib["patient"]
/ file_attrib["MR"]
/ dcm_name
)
out_dir.parent.mkdir(parents=True, exist_ok=True)
# print(out_dir)
np.save(str(out_dir) + "_img", img.cpu().numpy())
np.save(str(out_dir) + "_logit", logit.cpu().numpy())
np.save(str(out_dir) + "_pred", pred.cpu().numpy())
copy(dcm_path, out_dir.with_suffix(".dcm"))
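                # each slice ends up saved next to a copy of its source DICOM:
                #   {dcm_name}_img.npy, {dcm_name}_logit.npy,
                #   {dcm_name}_pred.npy, {dcm_name}_attrib.json, {dcm_name}.dcm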
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
                        return super().default(obj)
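
                # Example: json.dumps({"dim": np.int64(256)}, cls=NpEncoder)
                # returns '{"dim": 256}' instead of raising a TypeError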
                # get the resize transform within the compose object;
                # the None default lets the assert below fire instead of
                # an unhandled StopIteration when no Resize is present
                Resize = albumentations.augmentations.transforms.Resize
                transform_resize = next(
                    (
                        v
                        for v in dataloader.dataset.augmentation.transforms
                        if isinstance(v, Resize)
                    ),
                    None,
                )
                assert (
                    transform_resize is not None
                ), "transform_resize must be defined"
file_attrib["transform_resize_dim"] = (
transform_resize.height,
transform_resize.width,
)
                attrib_json = json.dumps(file_attrib, cls=NpEncoder)
                with open(str(out_dir) + "_attrib.json", "w") as f:
                    f.write(attrib_json)
# %%
def resized_stack(numpy_list, dsize=None):
"""resizing lists of array with dimension:
slices x 1 x H x W, where H = W.
Sets output size to first array at idx 0 or dsize
Args:
numpy_list (list): list of numpy arr
dsize (int, optional): output dimension. Defaults to None.
Returns:
numpy: stacked numpy lists with same size
"""
assert numpy_list[0].shape[1] == 1, "dimension check"
assert numpy_list[0].shape[2] == numpy_list[0].shape[3], "square check"
def reshape(arr):
"""reshapes [slices x 1 x H x W] to [H x W x slices]"""
arr = np.moveaxis(arr, 0, -1) # slices to end
arr = np.squeeze(arr) # remove 1 dimension
return arr
reshaped = [reshape(arr) for arr in numpy_list]
if dsize is None:
dsize = reshaped[0].shape[0:2] # get H, W from first arr
resized = [
cv2.resize(src, dsize, interpolation=cv2.INTER_CUBIC)
for src in reshaped
]
return np.stack(resized)
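
# %%
# Minimal usage sketch for resized_stack (shapes are illustrative, not taken
# from the dataset): two single-channel volumes with different in-plane sizes
# are resized to match the first and stacked along a new leading axis.
demo_vols = [
    np.zeros((10, 1, 256, 256), dtype=np.float32),
    np.zeros((10, 1, 128, 128), dtype=np.float32),
]
print(resized_stack(demo_vols).shape)  # -> (2, 256, 256, 10)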
def display_volumes(
study_dir,
style="prob",
plot_error=False,
skip_display=True,
save_dir=None,
output_style="png",
):
"""Displays inference over original image
Args:
study_dir (path): Directory of inferences.
style (str, optional): Type of data displayed.
Defaults to "prob" for probability.
plot_error (bool, optional): Display error. Defaults to False.
skip_display (bool, optional): Display console display.
save_dir (path, optional): Directory to save figs. Defaults to None.
Returns:
dict: Dictionary of images, logits, predictions, probs
"""
print(f"loading from {study_dir}")
study_dir = Path(study_dir)
imgs = sorted(study_dir.glob("*_img.npy"))
imgs_np = [np.load(i) for i in imgs]
logits = sorted(study_dir.glob("*_logit.npy"))
logits_np = [np.load(logit) for logit in logits]
preds = sorted(study_dir.glob("*_pred.npy"))
preds_np = [np.load(p) for p in preds]
vols = {
"img": np.stack(imgs_np),
"logit": np.stack(logits_np),
"pred": np.stack(preds_np),
"prob": torch.sigmoid(torch.from_numpy(np.stack(logits_np))).numpy(),
}
def show(img, label=None, img_alpha=1, lb_alpha=0.3):
npimg = img.numpy()
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(
np.transpose(npimg, (1, 2, 0)),
interpolation="none",
alpha=img_alpha,
)
if label is not None:
lbimg = label.numpy()
ax.imshow(
np.transpose(lbimg, (1, 2, 0)),
alpha=lb_alpha,
interpolation="none",
)
x = torch.from_numpy(vols["img"])
y = vols[style]
def norm_tensor(x):
x = x / x.sum(0).expand_as(x)
x[torch.isnan(x)] = 0
return x
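    # mask values at or below bkgrd_thresh so the background stays
    # transparent, then map the rest through the inferno colormap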
bkgrd_thresh = 0.01
cmap_vol = np.ma.masked_where(y <= bkgrd_thresh, y)
cmap_vol = np.apply_along_axis(cm.inferno, 0, cmap_vol)
cmap_vol = torch.from_numpy(np.squeeze(cmap_vol))
print(f"style is: {style}")
print(f"vol stats: min:{y.min()} max:{y.max()} mean:{y.mean()}")
if not skip_display:
if style == "img":
show(make_grid(x), lb_alpha=0.5)
plt.show()
else:
show(make_grid(x), make_grid(cmap_vol), lb_alpha=0.5)
if save_dir is None:
plt.show()
else:
os.makedirs(save_dir, exist_ok=True)
plt.savefig(Path(save_dir) / "label_grid.svg")
plt.savefig(Path(save_dir) / "label_grid.png")
return y
def exam_preds_to_stat(
pred_vol, ground_vol, pred_process, attrib_dict, pred_std=None
):
"""computes stats for a single exam prediction
Args:
pred_vol (numpy): prediction volume
ground_vol (numpy): ground truth volume
pred_process (function): converts prediction to binary
attrib (dict): dictionary of attributes (usually from index 0)
Returns:
tuple: study key, dictionary of attributes
"""
dice = Dice(
pred_process=pred_process, use_as_loss=False, power=1, dim=(0, 1, 2, 3)
)
dice_val = dice(
torch.from_numpy(pred_vol), torch.from_numpy(ground_vol)
).item()
scale_factor = (attrib_dict["dim"][0] ** 2) / (
attrib_dict["transform_resize_dim"][0] ** 2
)
# print(f"scale factor {scale_factor}")
pred_pixel_count = torch.sum(
pred_process(torch.from_numpy(pred_vol))
).item()
volume_pred = scale_factor * attrib_dict["vox_vol"] * pred_pixel_count
ground_pixel_count = torch.sum(
pred_process(torch.from_numpy(ground_vol))
).item()
volume_ground = scale_factor * attrib_dict["vox_vol"] * ground_pixel_count
attrib_dict.update(
{
"TKV_GT": volume_ground,
"TKV_Pred": volume_pred,
"patient_dice": dice_val,
"study": attrib_dict["patient"] + attrib_dict["MR"],
"scale_factor": scale_factor,
"Pred_stdev": pred_std,
}
)
return attrib_dict
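
# Worked example for the scale factor above (illustrative numbers): a 512x512
# acquisition resized to 256x256 for the network gives
# scale_factor = 512**2 / 256**2 = 4, so each predicted pixel stands in for
# four voxel volumes when TKV is estimated at the acquisition resolution.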
def compute_inference_stats(
save_dir, output=False, display=False, patient_ID=None
):
Metric_data = OrderedDict()
Combined_metric_data = OrderedDict()
root = Path.cwd() / Path(save_dir)
model_inferences = list(root.glob("*"))
newline = "\n"
formated_list = "".join([f"{newline} {m}" for m in model_inferences])
print(f"calculating model inferences for {formated_list}")
all_logit_vol = defaultdict(list)
all_pred_vol = defaultdict(list)
all_ground_vol = defaultdict(list)
all_summaries = defaultdict(list)
pred_process = SigmoidBinarize(thresholds=[0.5])
for model_dir in tqdm(model_inferences):
if patient_ID is not None:
MR_num = "*"
else:
patient_ID, MR_num = "*", "*"
studies = model_dir.glob(f"{patient_ID}/{MR_num}")
for study_dir in studies:
imgs = sorted(study_dir.glob("*_img.npy"))
imgs_np = [np.load(i) for i in imgs]
logits = sorted(study_dir.glob("*_logit.npy"))
logits_np = [np.load(logit) for logit in logits]
preds = sorted(study_dir.glob("*_pred.npy"))
preds_np = [np.load(p) for p in preds]
grounds = sorted(study_dir.glob("*_ground.npy"))
grounds_np = [np.load(g) for g in grounds]
attribs = sorted(study_dir.glob("*_attrib.json"))
attribs_dicts = []
for a in attribs:
with open(a) as json_file:
attribs_dicts.append(json.load(json_file))
# volumes for a study within one model inference
img_vol = np.stack(imgs_np)
logit_vol = np.stack(logits_np)
pred_vol = np.stack(preds_np)
ground_vol = np.stack(grounds_np)
            if display is True:
                # display_volumes expects a study directory, not raw volumes
                display_volumes(study_dir, style="pred", skip_display=False)
summary = exam_preds_to_stat(
pred_vol, ground_vol, pred_process, attribs_dicts[0]
)
Metric_data[summary["study"]] = summary
# accumulate predictions across all models for each study
all_logit_vol[summary["study"]].append(logit_vol)
all_pred_vol[summary["study"]].append(pred_vol)
all_ground_vol[summary["study"]].append(ground_vol)
all_summaries[summary["study"]].append(summary)
df = pd.DataFrame(Metric_data).transpose()
if output is True:
df.to_csv(f"stats-{model_dir.name}.csv")
for key, value in all_logit_vol.items():
# uses index 0 to get ground truth and standard voxel attribs
def sigmoid(x):
return 1 / (1 + np.exp(-x))
        # resize all per-model volumes to the shape of the volume at idx 0,
        # then average across models
        prob_vol = resized_stack(value)
        # note: the raw logits are averaged directly; sigmoid is skipped
        # prob_vol = sigmoid(prob_vol)
prob_vol = np.mean(prob_vol, axis=0)
prob_std = np.std(prob_vol)
prob_vol = np.moveaxis(prob_vol, -1, 0) # b x (X x Y)
prob_vol = np.expand_dims(prob_vol, axis=1) # b x c x (X x Y)
pred_vol = binarize_thresholds(torch.from_numpy(prob_vol)).numpy()
ground_vol = all_ground_vol[key][0]
summary = exam_preds_to_stat(
pred_vol,
ground_vol,
pred_process,
all_summaries[key][0],
pred_std=prob_std,
)
Combined_metric_data[summary["study"]] = summary
df = pd.DataFrame(Combined_metric_data).transpose()
if output is True:
print("saving combined csv")
df.to_csv("stats-combined_models.csv")
# %%
# Load models
paths = [
"checkpoints/inference.yml",
]
# %%
# Run inferences
for p in tqdm(paths):
model_args = load_config(config_path=p)
inference_to_disk(*model_args)
# %%
# Show sample inference
root_saved_inference = Path("saved_inference")
study = "adpkd-segmentation/1001707280/Axial T2 SS-FSE"
y = display_volumes(
study_dir=root_saved_inference / study,
style="pred",
plot_error=True,
skip_display=False,
)
# %%
# Creating figures for all inferences
# Get all model inferences
inference_files = Path("saved_inference").glob("**/*")
# Folders are of form 'saved_inference/adpkd-segmentation/{PATIENT-ID}/{SERIES}'
folders = [f.parent for f in inference_files if len(f.parent.parts) == 4]
folders = list(set(folders))
IDX_series = -1
IDX_ID = -2
saved_folders = [
Path("saved_figs") / f"{d.parts[IDX_ID]}_{d.parts[IDX_series]}"
for d in folders
]
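# e.g. saved_inference/adpkd-segmentation/1001707280/Axial T2 SS-FSE
# maps to saved_figs/1001707280_Axial T2 SS-FSE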
# %%
# Generate figures for all inferences
# NOTE: the "[17:]" slice resumes partway through the folder list;
# remove it to process every study
for study_dir, save_dir in tqdm(list(zip(folders, saved_folders))[17:]):
try:
# Save inference figure to save_dir
y = display_volumes(
study_dir=study_dir,
style="pred",
plot_error=True,
skip_display=False,
save_dir=save_dir,
)
except Exception as e:
print(e)
# %%
inference_files = Path("saved_inference").glob("**/*img.npy")
folders = list(set([f.parent for f in inference_files]))
print(folders)
# %%
for folder in folders:
print(f"Folder: {folder}")
imgs = list(folder.glob("*img.npy"))
imgs = sorted(imgs)
np_imgs = [(np.load(i))[0] for i in imgs]
np_vol = np.stack(np_imgs).T
print(np_vol.shape)
print(f"{folder.name}_img_vol.nii")
dicom_paths = list(folder.glob("*.dcm"))
dicom_paths = sorted(dicom_paths)
reader = sitk.ImageSeriesReader()
reader.SetFileNames([str(path) for path in dicom_paths])
errors = []
try:
image_3d = reader.Execute()
sitk.WriteImage(
image_3d,
str(folder / f"{folder.name}_original.nii"), # noqa
)
except Exception as e:
print(e)
errors.append(f"error:{str(e)}\n folder:{folder}")
    nifti_vol = nib.Nifti1Image(np_vol, affine=np.eye(4))
    nib.save(nifti_vol, folder / f"{folder.name}_img_vol.nii")