# experiments/collect_generalizability_metrics.py
1
import pickle
2
from os import listdir
3
from os.path import join
4
import pickle as pkl
5
import matplotlib.pyplot as plt
6
from tqdm import tqdm
7
import pandas as pd
8
import torch
9
import numpy as np
10
from data.etis import EtisDataset
11
from data.hyperkvasir import KvasirSegmentationDataset
12
from data.endocv import EndoCV2020
13
from data.cvc import CVC_ClinicDB
14
from models.segmentation_models import *
15
from models.ensembles import *
16
from evaluation import metrics
17
from torch.utils.data import DataLoader
18
from perturbation.model import ModelOfNaturalVariation
19
import random
20
21
22
class ModelEvaluator:
    """Computes generalizability metrics for trained segmentation predictors.

    For each checkpoint id, evaluates mean IoU and segmentation inconsistency
    (SIS, via ``ModelOfNaturalVariation``) on the in-distribution Kvasir-SEG
    test split plus three out-of-distribution polyp datasets, then pickles the
    resulting ``(dataset, id)`` matrices under ``experiments/Data/pickles``.
    """

    def __init__(self):
        # Out-of-distribution datasets; the HyperKvasir test split (first
        # dataloader below) is the in-distribution reference.
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        self.dataloaders = [
                               DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
                           [DataLoader(dataset) for dataset in self.datasets]
        # Must stay index-aligned with self.dataloaders.
        self.dataset_names = ["Kvasir-Seg", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        # Other architectures (DeepLab, FPN, InductiveNet, TriUnet) are
        # evaluated by swapping these two (index-aligned) lists.
        self.models = [Unet]
        self.model_names = ["Unet"]

    def parse_experiment_details(self, model_name, eval_method, loss_fn, aug, id, last_epoch=False):
        """Resolve a checkpoint path from the experiment configuration and load it.

        Note: supremely messy, since the file structure just sort of evolved.

        :param model_name: architecture name, e.g. "Unet"
        :param eval_method: evaluation-method path component ("" or "maximum_consistency")
        :param loss_fn: "sil" selects consistency-trained checkpoints; anything else, augmentation-trained
        :param aug: augmentation flag; "0" means vanilla (unaugmented) training, "G" adds the inpainter prefix
        :param id: integer checkpoint id appended to the filename
        :param last_epoch: if True (vanilla only), load the last-epoch checkpoint instead of the best one
        :return: tuple ``(state_dict, path)``
        :raises FileNotFoundError: if no checkpoint exists at the resolved path
        """
        path = "Predictors"
        if aug != "0":
            path = join(path, "Augmented")
            path = join(path, model_name)
            path = join(path, eval_method)
            if aug == "G":
                path += "inpainter_"
            if loss_fn == "sil":
                path += "consistency"
            else:
                if model_name == "InductiveNet" and aug != "V":
                    # Historical typo in the training script's output name; kept
                    # so existing checkpoints remain loadable.
                    path += "zaugmentation"  # oops
                else:
                    path += "augmentation"
            path += f"_{id}"
        else:
            path = join(path, "Vanilla")
            path = join(path, model_name)
            path = join(path, "vanilla")
            path += f"_{id}"
            if eval_method == "maximum_consistency":
                path += "-maximum-consistency"
            elif last_epoch:
                path += "_last_epoch"
        return torch.load(path), path

    def get_table_data(self, sample_range, id_range, show_reconstruction=False, show_consistency_examples=False):
        """Evaluate all configured checkpoints on all dataloaders and pickle the results.

        :param sample_range: unused; kept for call-site compatibility
        :param id_range: iterable of consecutive checkpoint ids; each id becomes one
            column of the result matrices (column index ``id - id_range[0]``)
        :param show_reconstruction: if True, plot reconstruction examples for the
            first batch of each dataloader (requires a dual-output model)
        :param show_consistency_examples: if True, plot perturbation-consistency
            examples for the in-distribution dataloader
        """
        mnv = ModelOfNaturalVariation(1)
        for model_constructor, model_name in zip(self.models, self.model_names):
            # Training-configuration axes; currently restricted to a single combination.
            for eval_method in [""]:
                for loss_fn in ["j"]:
                    for aug in ["G"]:
                        sis_matrix = np.zeros((len(self.dataloaders), len(id_range)))
                        mean_ious = np.zeros((len(self.dataloaders), len(id_range)))
                        for id in id_range:
                            try:
                                state_dict, full_name = self.parse_experiment_details(model_name, eval_method, loss_fn,
                                                                                      aug,
                                                                                      id)
                                model = model_constructor().to("cuda")
                                model.load_state_dict(state_dict)
                                print(f"Evaluating {full_name}")
                            except FileNotFoundError:
                                print(f"{model_name}-{eval_method}-{loss_fn}-{aug}-{id} not found, continuing...")
                                continue
                            for dl_idx, dataloader in enumerate(self.dataloaders):
                                # seeding ensures SIS metrics are non-stochastic
                                np.random.seed(0)
                                torch.manual_seed(0)
                                random.seed(0)

                                for i, (x, y, _) in enumerate(dataloader):
                                    img, mask = x.to("cuda"), y.to("cuda")
                                    aug_img, aug_mask = mnv(img, mask)
                                    out = model.predict(img)
                                    aug_out = model.predict(aug_img)

                                    if dl_idx == 0 and show_consistency_examples:
                                        fig, ax = plt.subplots(2, 3)
                                        xor = lambda a, b: a * (1 - b) + b * (1 - a)
                                        # Inconsistency mask: pixels where the change in the
                                        # prediction disagrees with the change in the ground truth.
                                        diff = xor(xor(out, aug_out), xor(mask, aug_mask))
                                        union = torch.clamp((out + aug_out + mask + aug_mask), 0, 1)
                                        fig.suptitle(
                                            f"Inconsistency: {metrics.sis(aug_mask, mask, aug_out, out)}")
                                        ax[0, 0].imshow(img[0].cpu().numpy().T)
                                        ax[0, 0].set_title("Unperturbed Image")
                                        ax[1, 0].imshow(aug_img[0].cpu().numpy().T)
                                        ax[1, 0].set_title("Perturbed Image")
                                        ax[0, 1].imshow(out[0].cpu().numpy().T)
                                        ax[0, 1].set_title("Unperturbed Output")
                                        ax[1, 1].imshow(aug_out[0].cpu().numpy().T)
                                        ax[1, 1].set_title("Perturbed Output")
                                        print(ax[1, 1].get_position())
                                        ax[0, 2].imshow(diff[0].cpu().numpy().T, cmap="viridis")
                                        ax[0, 2].set_title("Inconsistency")
                                        ax[0, 2].set_position([0.67, 0.34, 0.90, 0])
                                        for axi in ax.flatten():
                                            axi.set_yticks([])
                                            axi.set_xticks([])
                                            for side in ("top", "right", "bottom", "left"):
                                                axi.spines[side].set_visible(False)
                                        plt.subplots_adjust(wspace=0.1, hspace=0.1)
                                        plt.show()

                                    if i == 0 and dl_idx == 0 and show_consistency_examples:
                                        with torch.no_grad():
                                            fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(2, 2), dpi=1000,
                                                                   sharex=True, sharey=True)
                                            # NOTE(review): unpacking two outputs assumes a dual-decoder
                                            # model (segmentation + reconstruction) — confirm before use.
                                            # This also rebinds `out`, which the metrics below will see.
                                            out, reconstruction = model(img)
                                            img_n = img + torch.rand_like(img) / 2.5
                                            out_n, reconstruction_n = model(img_n)
                                            xor = lambda a, b: a * (1 - b) + b * (1 - a)
                                            diff = xor(xor(out, aug_out), xor(mask, aug_mask))
                                            union = torch.clamp((out + out_n), 0, 1)
                                            ax[0, 0].imshow(img[0].cpu().numpy().T)
                                            ax[0, 0].set_title("Unperturbed Image")
                                            ax[1, 0].imshow(img_n[0].cpu().numpy().T)
                                            ax[1, 0].set_title("Perturbed Image")
                                            ax[0, 1].imshow(out[0].cpu().numpy().T)
                                            ax[0, 1].set_title("Unperturbed Output")
                                            ax[1, 1].imshow(out_n[0].cpu().numpy().T)
                                            ax[1, 1].set_title("Perturbed Output")
                                            ax[0, 2].imshow(diff[0].cpu().numpy().T)
                                            ax[0, 2].set_title("Inconsistency")
                                            for axi in ax.flatten():
                                                axi.title.set_size(3.5)
                                                axi.set_yticks([])
                                                axi.set_xticks([])
                                                for side in ("top", "right", "bottom", "left"):
                                                    axi.spines[side].set_visible(False)
                                            plt.subplots_adjust(wspace=0.1, hspace=0.1)
                                            plt.savefig("consistency_examples.png")
                                            plt.show()
                                            # Pause so each figure can be inspected before continuing.
                                            input()

                                    if show_reconstruction and i == 0:
                                        with torch.no_grad():
                                            # Rebinds `out`; the metrics below will use this value.
                                            out, reconstruction = model(img)
                                            # NOTE(review): `ax` is not created on this path (the
                                            # subplot setup was commented out upstream); re-enable it
                                            # before using show_reconstruction=True.
                                            # BUGFIX: original used `i` as the loop variable here,
                                            # clobbering the enumerate counter and breaking the
                                            # `i == 0` checks for later batches.
                                            for col in range(4):
                                                ax[0, col].title.set_text(self.dataset_names[col])
                                                ax[0, col].title.set_size(8)
                                            ax[0, 0].set_ylabel("Original", fontsize=8)
                                            ax[1, 0].set_ylabel("Reconstruction", fontsize=8)
                                            ax[2, 0].set_ylabel("L1", fontsize=8)
                                            for axi in ax.flatten():
                                                axi.set_yticks([])
                                                axi.set_xticks([])
                                                for side in ("top", "right", "bottom", "left"):
                                                    axi.spines[side].set_visible(False)
                                            ax[0, dl_idx].imshow(x[0].cpu().numpy().T)
                                            ax[1, dl_idx].imshow(reconstruction[0].cpu().numpy().T)
                                            ax[2, dl_idx].imshow(
                                                np.mean(
                                                    np.abs(reconstruction[0].cpu().numpy().T - x[0].cpu().numpy().T),
                                                    axis=-1))

                                    iou = metrics.iou(out, mask)
                                    # consistency
                                    sis = metrics.sis(aug_mask, mask, aug_out, out)
                                    # Accumulate batch-wise averages into column id - id_range[0].
                                    sis_matrix[dl_idx, id - id_range[0]] += sis / len(dataloader)
                                    mean_ious[dl_idx, id - id_range[0]] += iou / len(dataloader)

                            # BUGFIX: original indexed [0, id - 1], which is only correct
                            # when id_range starts at 1; use the same offset as above.
                            if mean_ious[0, id - id_range[0]] < 0.8:
                                print(f"{full_name} has iou {mean_ious[0, id - id_range[0]]} ")
                        with open(f"experiments/Data/pickles/{model_name}_{eval_method}_{loss_fn}_{aug}.pkl",
                                  "wb") as file:
                            pickle.dump({"ious": mean_ious, "sis": sis_matrix}, file)
239
240
class SingularEnsembleEvaluator:
    """Evaluates ensembles built from several checkpoints of one architecture.

    For each architecture, builds ``samples`` ensembles of ``model_count``
    checkpoints each, computes mean IoU over four dataloaders, and pickles the
    IoU matrix together with the constituent checkpoint names.
    """

    def __init__(self, samples=10):
        # Out-of-distribution datasets; the HyperKvasir test split (first
        # dataloader below) is the in-distribution reference.
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        self.dataloaders = [
                               DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
                           [DataLoader(dataset) for dataset in self.datasets]
        # Must stay index-aligned with self.dataloaders.
        self.dataset_names = ["HyperKvasir", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        self.models = [DeepLab, FPN, InductiveNet, TriUnet, Unet]
        self.model_names = ["DeepLab", "FPN", "InductiveNet", "TriUnet", "Unet"]
        # Number of independently-sampled ensembles per architecture.
        self.samples = samples

    def get_table_data(self, model_count):
        """Evaluate ``samples`` ensembles of ``model_count`` models per architecture.

        :param model_count: number of constituent checkpoints per ensemble
        """
        # Renamed from `type` (shadowed the builtin); "augmentation" selects
        # augmentation-trained constituent checkpoints.
        for ensemble_type in ["augmentation"]:
            for model_name in self.model_names:
                print(model_name)
                mean_ious = np.zeros((len(self.dataloaders), self.samples))
                constituents = {}
                for i in range(self.samples):
                    model = SingularEnsemble(model_name, ensemble_type, model_count)
                    constituents[i] = model.get_constituents()
                    for dl_idx, dataloader in enumerate(self.dataloaders):
                        for x, y, _ in tqdm(dataloader):
                            img, mask = x.to("cuda"), y.to("cuda")
                            out = model.predict(img, threshold=True)
                            iou = metrics.iou(out, mask)
                            # Accumulate the per-batch average so the column ends
                            # up holding the dataset-wide mean IoU.
                            mean_ious[dl_idx, i] += iou / len(dataloader)
                    del model  # avoid memory issues
                print(mean_ious)
                # Legacy naming: consistency-trained ensembles were pickled
                # without the type suffix. Only the filename differs, so the
                # duplicated dump branches are collapsed into one.
                if ensemble_type == "consistency":
                    out_path = f"experiments/Data/pickles/{model_name}-ensemble-{model_count}.pkl"
                else:
                    out_path = f"experiments/Data/pickles/{model_name}-ensemble-{model_count}-{ensemble_type}.pkl"
                with open(out_path, "wb") as file:
                    pickle.dump({"ious": mean_ious, "constituents": constituents}, file)
286
class DiverseEnsembleEvaluator:
    """Evaluates ensembles composed of different architectures.

    Builds ``samples`` diverse ensembles (one per seed/id), computes mean IoU
    over four dataloaders, and pickles the IoU matrix together with each
    ensemble's constituent checkpoint names.
    """

    def __init__(self, samples=10):
        # Out-of-distribution datasets; the HyperKvasir test split (first
        # dataloader below) is the in-distribution reference.
        self.datasets = [
            EtisDataset("Datasets/ETIS-LaribPolypDB"),
            CVC_ClinicDB("Datasets/CVC-ClinicDB"),
            EndoCV2020("Datasets/EndoCV2020"),
        ]
        self.dataloaders = [
                               DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", split="test"))] + \
                           [DataLoader(dataset) for dataset in self.datasets]
        # Must stay index-aligned with self.dataloaders.
        self.dataset_names = ["HyperKvasir", "Etis-LaribDB", "CVC-ClinicDB", "EndoCV2020"]
        self.models = [DeepLab, FPN, InductiveNet, TriUnet, Unet]
        self.model_names = ["DeepLab", "FPN", "InductiveNet", "TriUnet", "Unet"]
        # Number of diverse ensembles to evaluate (ids 1..samples).
        self.samples = samples

    def get_table_data(self):
        """Evaluate diverse ensembles with ids 1..samples and pickle the results."""
        # Renamed from `type` (shadowed the builtin); "augmentation" selects
        # augmentation-trained constituent checkpoints.
        for ensemble_type in ["augmentation"]:
            mean_ious = np.zeros((len(self.dataloaders), self.samples))
            constituents = {}
            for i in range(1, self.samples + 1):
                model = DiverseEnsemble(i, ensemble_type)
                constituents[i] = model.get_constituents()
                for dl_idx, dataloader in enumerate(self.dataloaders):
                    for x, y, _ in tqdm(dataloader):
                        img, mask = x.to("cuda"), y.to("cuda")
                        out = model.predict(img)
                        iou = metrics.iou(out, mask)
                        # Accumulate the per-batch average so column i-1 ends up
                        # holding the dataset-wide mean IoU for ensemble i.
                        mean_ious[dl_idx, i - 1] += iou / len(dataloader)
                # Flag under-performing ensembles on the in-distribution set.
                if mean_ious[0, i - 1] < 0.80:
                    print(f"{i} has iou {mean_ious[0, i - 1]}")
                print(mean_ious)
            with open(f"experiments/Data/pickles/diverse-ensemble-{ensemble_type}.pkl",
                      "wb") as file:
                pickle.dump({"ious": mean_ious, "constituents": constituents}, file)
323
def write_to_latex_table(pkl_file):
    """Stub for rendering pickled metrics into a LaTeX table.

    Currently only loads the template; filling it from ``pkl_file`` is not
    yet implemented.

    :param pkl_file: path to a metrics pickle (currently unused)
    :return: the raw contents of the ``table_template`` file
    """
    # BUGFIX: original opened the template without closing it (leaked the
    # file handle) and discarded the contents.
    with open("table_template") as template_file:
        table_template = template_file.read()
    return table_template
327
if __name__ == '__main__':
    # Compact numpy printing for the matrices dumped during evaluation.
    np.set_printoptions(precision=3, suppress=True)
    single_model_evaluator = ModelEvaluator()
    single_model_evaluator.get_table_data(np.arange(0, 10), np.arange(1, 11), show_reconstruction=False,
                                          show_consistency_examples=False)
    # Alternative evaluation runs, enabled by uncommenting:
    # evaluator = DiverseEnsembleEvaluator(samples=10)
    # evaluator.get_table_data()
    # evaluator = SingularEnsembleEvaluator()
    # evaluator.get_table_data(5)
    #
    # get_metrics_for_experiment("Augmented", "consistency_1")