[637b40]: / notebooks / initial_error_analysis.py

Download this file

147 lines (118 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Check low dice examples"""
# %%
import os
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# enable lib loading even if not installed as a pip package or in PYTHONPATH
# also convenient for relative paths in example config files
# The chdir below moves the working directory one level above this file's
# directory (from notebooks/ to the repository root), so the relative CONFIG
# paths further down resolve against the repo root.
# NOTE(review): the project imports only benefit from the chdir when the
# interpreter searches the current directory (e.g. '' on sys.path, typical of
# interactive/notebook sessions); a plain `python notebooks/<file>.py` run puts
# the script's own directory on sys.path instead — confirm how this is run.
from pathlib import Path
os.chdir(Path(__file__).resolve().parent.parent)
# Project imports must follow the chdir, hence the "noqa" markers (presumably
# silencing flake8 E402, module-level import not at top of file).
from adpkd_segmentation.config.config_utils import get_object_instance # noqa
from adpkd_segmentation.evaluate import validate # noqa
from adpkd_segmentation.utils.train_utils import load_model_data # noqa
# Apply seaborn's default theme to every plot produced by this notebook.
sns.set()
# %%
# Experiment config to analyse. Alternatives used in earlier runs:
#   experiments/september/11_stratified_albu_v2_b4_simple_norm/val/val.yaml
#   ./experiments/september06/random_split_new_data_less_albu/test/test.yaml
CONFIG = "./experiments/september03/random_split_new_data_less_albu/test/test.yaml"
with open(CONFIG) as config_file:
    # full_load(stream) is PyYAML's shorthand for load(stream, FullLoader)
    config = yaml.full_load(config_file)
# %%
# Pull the sections this notebook needs out of the experiment config, in order:
# model, test dataloader, losses/metrics, checkpoint path.
model_config, dataloader_config, losses_config, saved_checkpoint = (
    config[section]
    for section in (
        "_MODEL_CONFIG",
        "_TEST_DATALOADER_CONFIG",
        "_LOSSES_METRICS_CONFIG",
        "_MODEL_CHECKPOINT",
    )
)
# override: one example per batch. The cells below use the position of a dice
# score as a dataset index, which presumably relies on this (and on the test
# dataloader not shuffling — TODO confirm).
dataloader_config["batchsize"] = 1
# %%
# get_object_instance(cfg) returns a callable; calling it builds the object.
# The model is built and its checkpoint weights loaded first (the meaning of
# the trailing True flag is defined by load_model_data — TODO confirm).
model = get_object_instance(model_config)()
load_model_data(saved_checkpoint, model, True)
# Then the test dataloader and the loss/metric object, in that order.
dataloader, loss_metric = (
    get_object_instance(sub_config)()
    for sub_config in (dataloader_config, losses_config)
)
# %%
# Everything runs on the CPU. eval() switches the network to inference
# behaviour (e.g. dropout / batch-norm in eval mode).
device = torch.device("cpu", 0)
model = model.to(device)
model.eval()
# %%
# Run the full evaluation loop over the test dataloader.
# With output_losses_list=True, validate also returns the individual values
# per batch — presumably a mapping from metric name to a list, given the
# "dice_metric" lookup below (TODO confirm). ``averaged`` presumably holds the
# mean of each metric. Because batchsize is overridden to 1, each list entry
# should correspond to a single dataset example.
averaged, all_losses_and_metrics = validate(
    dataloader, model, loss_metric, device, output_losses_list=True
)
# %%
# Per-example Dice scores (presumably one per batch of size 1).
dice_scores = all_losses_and_metrics["dice_metric"]
# %%
# Dice against example index; points near 0 are the failures inspected below.
n_examples = len(dice_scores)
plt.xlabel("Example index")
plt.ylabel("Dice score")
sns.scatterplot(x=range(n_examples), y=dice_scores, alpha=0.4)
# %%
# (dice, dataset index) pairs for slices the model essentially missed.
very_low_dice = [
    (score, position)
    for position, score in enumerate(dice_scores)
    if score < 0.05
]
# %%
# check one example: ``example_position`` indexes into ``very_low_dice``
# (0-based, negatives count from the end), not into the dataset.
example_position = 1
if not -len(very_low_dice) <= example_position < len(very_low_dice):
    raise IndexError(
        f"very_low_dice has {len(very_low_dice)} entries; "
        f"position {example_position} is out of range"
    )
_, example_idx = very_low_dice[example_position]
image, mask, _ = dataloader.dataset[example_idx]
# Sum of the ground-truth mask (the foreground pixel count if it is binary);
# a zero means an empty target, which may explain the near-zero Dice
# depending on how dice_metric treats empty masks.
print(mask.sum())
im_tensor = torch.from_numpy(image).unsqueeze(0)  # add a batch axis
# %%
# Inference only: no_grad skips autograd bookkeeping. The input is moved to
# the model's device.
with torch.no_grad():
    pred = model(im_tensor.to(device))  # (1, 1, 224, 224)
# %%
im = image[0]  # first channel of the (3, 224, 224) input -> 2-D slice
msk = mask[0]  # the (1, 224, 224) mask with its channel axis dropped
# pred holds raw scores (sigmoid is applied here). .cpu() keeps .numpy()
# working if ``device`` is ever switched to a GPU.
pred_sigm = torch.sigmoid(pred)[0][0].detach().cpu().numpy()
# %%
# Side by side: input | ground-truth overlay | predicted-probability overlay.
plt.rcParams["axes.grid"] = False  # seaborn's grid would be drawn over images
f, axarr = plt.subplots(1, 3)
axarr[0].imshow(im, cmap="gray")
for panel, overlay in zip(axarr[1:], (msk, pred_sigm)):
    panel.imshow(im, cmap="gray")  # grayscale background for the overlay
    panel.imshow(overlay, alpha=0.7)
# %%
# check all low dice: for every slice in ``very_low_dice`` show the input next
# to its ground-truth and predicted-probability overlays.
plt.rcParams["axes.grid"] = False  # loop-invariant; seaborn grid hides images
for dice, idx in very_low_dice:
    image, mask, _ = dataloader.dataset[idx]
    im_tensor = torch.from_numpy(image).unsqueeze(0)  # add a batch axis
    with torch.no_grad():  # inference only: no autograd graph per slice
        pred = model(im_tensor.to(device))
    im = image[0]  # first channel of the (3, 224, 224) input
    msk = mask[0]  # (1, 224, 224) mask with its channel axis dropped
    # .cpu() keeps .numpy() working if ``device`` is a GPU
    pred_sigm = torch.sigmoid(pred)[0][0].detach().cpu().numpy()
    f, axarr = plt.subplots(1, 3)
    # float() accepts a Python/NumPy scalar or a one-element tensor; dice is
    # scalar-like because it was compared with ``< 0.05`` to get here.
    f.suptitle(f"index {idx}, dice {float(dice):.3f}")
    axarr[0].imshow(im, cmap="gray")
    axarr[1].imshow(im, cmap="gray")  # background for mask
    axarr[1].imshow(msk, alpha=0.5)
    axarr[2].imshow(im, cmap="gray")  # background for mask
    axarr[2].imshow(pred_sigm, alpha=0.5)
# %%
# Slices with partial overlap: (dice, dataset index) pairs. The lower bound is
# inclusive so that a score of exactly 0.05 is kept. The very-low filter uses
# a strict ``< 0.05``, so otherwise such a score would fall in neither group.
middle_dice = [
    (dice, idx)
    for idx, dice in enumerate(dice_scores)
    if 0.05 <= dice < 0.8
]
# %%
def check_prediction(idx, *, alpha=0.5):
    """Plot dataset example *idx* beside its ground-truth and predicted masks.

    Uses the notebook-level ``dataloader``, ``model`` and ``device``. Opens a
    new 1x3 figure with these panels:
    - the input's first channel;
    - the ground-truth mask overlaid on it;
    - the sigmoid of the model output overlaid on it.

    Args:
        idx: index into ``dataloader.dataset``.
        alpha: opacity of the mask and prediction overlays (default 0.5).
    """
    image, mask, _ = dataloader.dataset[idx]
    im_tensor = torch.from_numpy(image).unsqueeze(0)  # add a batch axis
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        pred = model(im_tensor.to(device))
    im = image[0]  # first channel of the (3, 224, 224) input
    msk = mask[0]  # (1, 224, 224) mask with its channel axis dropped
    # .cpu() keeps .numpy() working if ``device`` is a GPU
    pred_sigm = torch.sigmoid(pred)[0][0].detach().cpu().numpy()
    plt.rcParams["axes.grid"] = False  # seaborn's grid would cover the images
    f, axarr = plt.subplots(1, 3)
    f.suptitle(f"index {idx}")
    axarr[0].imshow(im, cmap="gray")
    axarr[0].set_title("image")
    for ax, overlay, title in zip(
        axarr[1:], (msk, pred_sigm), ("ground truth", "prediction")
    ):
        ax.imshow(im, cmap="gray")  # background for the overlay
        ax.imshow(overlay, alpha=alpha)
        ax.set_title(title)
# %%
# Visualise every slice in the partial-overlap group.
for _score, slice_idx in middle_dice:
    check_prediction(slice_idx)
# %%
def get_patients(dice_scores, dataset, *, unique=False):
    """Return the patient for each ``(dice, index)`` pair in *dice_scores*.

    Args:
        dice_scores: iterable of ``(dice, index)`` pairs such as
            ``very_low_dice`` or ``middle_dice``; only the index is used.
        dataset: object with a ``get_verbose(index)`` method returning a
            3-tuple whose last item is a mapping with a ``"patient"`` key.
        unique: if True, drop repeated patients and keep first-seen order
            (patient values must then be hashable). Defaults to False.

    Returns:
        List of patient values, in the order of *dice_scores*.
    """
    patients = []
    for _, idx in dice_scores:
        _, _, attribs = dataset.get_verbose(idx)
        patients.append(attribs["patient"])
    if unique:
        # dicts keep insertion order, so this de-duplicates stably
        patients = list(dict.fromkeys(patients))
    return patients
# %%