experiments/collect_generalizability_metrics.py
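
"""Collects generalizability metrics for polyp-segmentation models.

Evaluates single models (and, optionally, ensembles) on Kvasir-SEG,
Etis-LaribDB, CVC-ClinicDB, and EndoCV2020, then pickles per-dataset
mean-IoU and SIS consistency matrices for later tabulation.
"""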
import pickle
import random
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.cvc import CVC_ClinicDB
from data.endocv import EndoCV2020
from data.etis import EtisDataset
from data.hyperkvasir import KvasirSegmentationDataset
from evaluation import metrics
from models.ensembles import *
from models.segmentation_models import *
from perturbation.model import ModelOfNaturalVariation


class ModelEvaluator:
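    """Evaluates single segmentation models across the four polyp datasets.

    For each checkpoint id, accumulates per-dataset mean IoU and the SIS
    consistency metric (under the Model of Natural Variation) and pickles
    the resulting matrices.
    """
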
    def __init__(self):
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        # Kvasir-SEG supplies the in-distribution test split; the datasets
        # above are used purely for out-of-distribution evaluation.
        self.dataloaders = [
            DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
            [DataLoader(dataset) for dataset in self.datasets]
        self.dataset_names = ["Kvasir-Seg", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        # Swap in any subset of [DeepLab, FPN, InductiveNet, TriUnet, Unet]
        # (with matching names below) to evaluate other architectures.
        self.models = [Unet]
        self.model_names = ["Unet"]
    def parse_experiment_details(self, model_name, eval_method, loss_fn, aug, id, last_epoch=False):
        """
        Map an experiment configuration to its checkpoint path and load it.
        Note: supremely messy, since the file structure just sort of evolved.
        """
        path = "Predictors"
        if aug != "0":
            path = join(path, "Augmented")
            path = join(path, model_name)
            path = join(path, eval_method)
            if aug == "G":  # generative (inpainter) augmentation
                path += "inpainter_"
            if loss_fn == "sil":
                path += "consistency"
            else:
                if model_name == "InductiveNet" and aug != "V":
                    path += "zaugmentation"  # oops; typo preserved to match the files on disk
                else:
                    path += "augmentation"
            path += f"_{id}"
        else:
            path = join(path, "Vanilla")
            path = join(path, model_name)
            path = join(path, "vanilla")
            path += f"_{id}"
        if eval_method == "maximum_consistency":
            path += "-maximum-consistency"
        elif last_epoch:
            path += "_last_epoch"
        return torch.load(path), path

    def get_table_data(self, sample_range, id_range, show_reconstruction=False, show_consistency_examples=False):
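        """
        Evaluate each configured checkpoint (one per id in id_range) on all
        dataloaders, accumulating per-dataset mean IoU and SIS, then pickle
        the matrices. Note: sample_range is currently unused.
        """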
        mnv = ModelOfNaturalVariation(1)
        for model_constructor, model_name in zip(self.models, self.model_names):
            for eval_method in [""]:
                for loss_fn in ["j"]:
                    for aug in ["G"]:
                        sis_matrix = np.zeros((len(self.dataloaders), len(id_range)))
                        mean_ious = np.zeros((len(self.dataloaders), len(id_range)))
                        for id in id_range:
                            col = id - id_range[0]  # column for this checkpoint id
                            try:
                                state_dict, full_name = self.parse_experiment_details(model_name, eval_method,
                                                                                      loss_fn, aug, id)
                                model = model_constructor().to("cuda")
                                model.load_state_dict(state_dict)
                                print(f"Evaluating {full_name}")
                            except FileNotFoundError:
                                print(f"{model_name}-{eval_method}-{loss_fn}-{aug}-{id} not found, continuing...")
                                continue
                            if show_reconstruction:
                                fig, ax = plt.subplots(ncols=4, nrows=3, figsize=(4, 3), dpi=1000)
                                fig.subplots_adjust(wspace=0, hspace=0)
                            for dl_idx, dataloader in enumerate(self.dataloaders):
                                # seeding ensures SIS metrics are non-stochastic
                                np.random.seed(0)
                                torch.manual_seed(0)
                                random.seed(0)

                                for i, (x, y, _) in enumerate(dataloader):
                                    img, mask = x.to("cuda"), y.to("cuda")
                                    aug_img, aug_mask = mnv(img, mask)
                                    out = model.predict(img)
                                    aug_out = model.predict(aug_img)

                                    if dl_idx == 0 and show_consistency_examples:
                                        fig, ax = plt.subplots(2, 3)
                                        # pixelwise XOR for soft binary masks
                                        xor = lambda a, b: a * (1 - b) + b * (1 - a)
                                        diff = xor(xor(out, aug_out), xor(mask, aug_mask))
                                        fig.suptitle(
                                            f"Inconsistency: {metrics.sis(aug_mask, mask, aug_out, out)}")
                                        ax[0, 0].imshow(img[0].cpu().numpy().T)
                                        ax[0, 0].set_title("Unperturbed Image")
                                        ax[1, 0].imshow(aug_img[0].cpu().numpy().T)
                                        ax[1, 0].set_title("Perturbed Image")
                                        ax[0, 1].imshow(out[0].cpu().numpy().T)
                                        ax[0, 1].set_title("Unperturbed Output")
                                        ax[1, 1].imshow(aug_out[0].cpu().numpy().T)
                                        ax[1, 1].set_title("Perturbed Output")
                                        ax[0, 2].imshow(diff[0].cpu().numpy().T, cmap="viridis")
                                        ax[0, 2].set_title("Inconsistency")
                                        ax[0, 2].set_position([0.67, 0.34, 0.90, 0])
                                        for axi in ax.flatten():
                                            axi.set_yticks([])
                                            axi.set_xticks([])
                                            for spine in axi.spines.values():
                                                spine.set_visible(False)
                                        plt.subplots_adjust(wspace=0.1, hspace=0.1)
                                        plt.show()

                                    if i == 0 and dl_idx == 0 and show_consistency_examples:
                                        # Noise-perturbation example; assumes a model with a
                                        # reconstruction head (e.g. InductiveNet).
                                        with torch.no_grad():
                                            fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(2, 2), dpi=1000,
                                                                   sharex=True, sharey=True)
                                            out, reconstruction = model(img)
                                            img_n = img + torch.rand_like(img) / 2.5
                                            out_n, reconstruction_n = model(img_n)
                                            xor = lambda a, b: a * (1 - b) + b * (1 - a)
                                            diff = xor(out, out_n)  # the mask is unchanged by the noise
                                            ax[0, 0].imshow(img[0].cpu().numpy().T)
                                            ax[0, 0].set_title("Unperturbed Image")
                                            ax[1, 0].imshow(img_n[0].cpu().numpy().T)
                                            ax[1, 0].set_title("Perturbed Image")
                                            ax[0, 1].imshow(out[0].cpu().numpy().T)
                                            ax[0, 1].set_title("Unperturbed Output")
                                            ax[1, 1].imshow(out_n[0].cpu().numpy().T)
                                            ax[1, 1].set_title("Perturbed Output")
                                            ax[0, 2].imshow(diff[0].cpu().numpy().T)
                                            ax[0, 2].set_title("Inconsistency")
                                            for axi in ax.flatten():
                                                axi.title.set_size(3.5)
                                                axi.set_yticks([])
                                                axi.set_xticks([])
                                                for spine in axi.spines.values():
                                                    spine.set_visible(False)
                                            plt.subplots_adjust(wspace=0.1, hspace=0.1)
                                            plt.savefig("consistency_examples.png")
                                            plt.show()
                                        input()  # pause so the figure can be inspected

                                    if show_reconstruction and i == 0:
                                        with torch.no_grad():
                                            out, reconstruction = model(img)
                                            for col_idx in range(4):
                                                ax[0, col_idx].title.set_text(self.dataset_names[col_idx])
                                                ax[0, col_idx].title.set_size(8)
                                            ax[0, 0].set_ylabel("Original", fontsize=8)
                                            ax[1, 0].set_ylabel("Reconstruction", fontsize=8)
                                            ax[2, 0].set_ylabel("L1", fontsize=8)
                                            for axi in ax.flatten():
                                                axi.set_yticks([])
                                                axi.set_xticks([])
                                                for spine in axi.spines.values():
                                                    spine.set_visible(False)
                                            ax[0, dl_idx].imshow(x[0].cpu().numpy().T)
                                            ax[1, dl_idx].imshow(reconstruction[0].cpu().numpy().T)
                                            # per-pixel L1 error between image and reconstruction
                                            ax[2, dl_idx].imshow(
                                                np.mean(
                                                    np.abs(reconstruction[0].cpu().numpy().T - x[0].cpu().numpy().T),
                                                    axis=-1))

                                    iou = metrics.iou(out, mask)
                                    sis = metrics.sis(aug_mask, mask, aug_out, out)  # consistency
                                    sis_matrix[dl_idx, col] += sis / len(dataloader)
                                    mean_ious[dl_idx, col] += iou / len(dataloader)

                            if mean_ious[0, col] < 0.8:
                                print(f"{full_name} has iou {mean_ious[0, col]}")
                            # re-dump after every id so partial results survive interruptions
                            with open(f"experiments/Data/pickles/{model_name}_{eval_method}_{loss_fn}_{aug}.pkl",
                                      "wb") as file:
                                pickle.dump({"ious": mean_ious, "sis": sis_matrix}, file)


class SingularEnsembleEvaluator:
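    """Evaluates ensembles built from multiple checkpoints of a single architecture."""
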
    def __init__(self, samples=10):
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        self.dataloaders = [
            DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
            [DataLoader(dataset) for dataset in self.datasets]
        self.dataset_names = ["HyperKvasir", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        self.models = [DeepLab, FPN, InductiveNet, TriUnet, Unet]
        self.model_names = ["DeepLab", "FPN", "InductiveNet", "TriUnet", "Unet"]
        self.samples = samples

    def get_table_data(self, model_count):
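        """Evaluate `samples` independently drawn ensembles of model_count checkpoints each."""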
        for train_type in ["augmentation"]:
            for model_name in self.model_names:
                print(model_name)
                mean_ious = np.zeros((len(self.dataloaders), self.samples))
                constituents = {}
                for i in range(self.samples):
                    model = SingularEnsemble(model_name, train_type, model_count)
                    constituents[i] = model.get_constituents()
                    for dl_idx, dataloader in enumerate(self.dataloaders):
                        for x, y, _ in tqdm(dataloader):
                            img, mask = x.to("cuda"), y.to("cuda")
                            out = model.predict(img, threshold=True)
                            iou = metrics.iou(out, mask)
                            mean_ious[dl_idx, i] += iou / len(dataloader)
                    del model  # avoid memory issues
                print(mean_ious)
                # consistency-trained ensembles keep the legacy file name (no type suffix)
                suffix = "" if train_type == "consistency" else f"-{train_type}"
                with open(f"experiments/Data/pickles/{model_name}-ensemble-{model_count}{suffix}.pkl",
                          "wb") as file:
                    pickle.dump({"ious": mean_ious, "constituents": constituents}, file)


class DiverseEnsembleEvaluator:
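    """Evaluates ensembles that mix the five configured architectures."""
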
    def __init__(self, samples=10):
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        self.dataloaders = [
            DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
            [DataLoader(dataset) for dataset in self.datasets]
        self.dataset_names = ["HyperKvasir", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        self.models = [DeepLab, FPN, InductiveNet, TriUnet, Unet]
        self.model_names = ["DeepLab", "FPN", "InductiveNet", "TriUnet", "Unet"]
        self.samples = samples

    def get_table_data(self):
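        """Evaluate each of the `samples` diverse ensembles on all dataloaders."""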
        for train_type in ["augmentation"]:
            mean_ious = np.zeros((len(self.dataloaders), self.samples))
            constituents = {}
            for i in range(1, self.samples + 1):
                model = DiverseEnsemble(i, train_type)
                constituents[i] = model.get_constituents()
                for dl_idx, dataloader in enumerate(self.dataloaders):
                    for x, y, _ in tqdm(dataloader):
                        img, mask = x.to("cuda"), y.to("cuda")
                        out = model.predict(img)
                        iou = metrics.iou(out, mask)
                        mean_ious[dl_idx, i - 1] += iou / len(dataloader)
                if mean_ious[0, i - 1] < 0.80:
                    print(f"{i} has iou {mean_ious[0, i - 1]}")
                print(mean_ious)
            with open(f"experiments/Data/pickles/diverse-ensemble-{train_type}.pkl",
                      "wb") as file:
                pickle.dump({"ious": mean_ious, "constituents": constituents}, file)


def write_to_latex_table(pkl_file):
    # TODO: unfinished stub; intended to render the pickled metrics into table_template.
    with open("table_template") as f:
        table_template = f.read()


if __name__ == '__main__':
    np.set_printoptions(precision=3, suppress=True)
    evaluator = ModelEvaluator()
    evaluator.get_table_data(np.arange(0, 10), np.arange(1, 11), show_reconstruction=False,
                             show_consistency_examples=False)
    # evaluator = DiverseEnsembleEvaluator(samples=10)
    # evaluator.get_table_data()
    # evaluator = SingularEnsembleEvaluator()
    # evaluator.get_table_data(5)
    # get_metrics_for_experiment("Augmented", "consistency_1")
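
# Illustrative only: reading the pickled metrics back for analysis. The file
# name below assumes the default Unet / "" / "j" / "G" configuration above.
#
#   with open("experiments/Data/pickles/Unet__j_G.pkl", "rb") as f:
#       results = pickle.load(f)
#   results["ious"]  # (n_datasets x n_ids) mean-IoU matrix
#   results["sis"]   # matching SIS consistency matrix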