training/inductivenet_trainers.py
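"""Trainers for InductiveNet variants: consistency training, vanilla training,
data-augmentation training, and ensemble training on HyperKvasir, with
out-of-distribution evaluation on ETIS-LaribPolypDB."""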
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data.etis import EtisDataset
from data.hyperkvasir import KvasirSegmentationDataset, KvasirMNVset
from evaluation.metrics import iou
# The star import below is also assumed to expose vanilla_losses; the original
# file used vanilla_losses.JaccardLoss() without an explicit import.
from losses.consistency_losses import *
from models.ensembles import TrainedEnsemble
from models.segmentation_models import InductiveNet
from perturbation.model import ModelOfNaturalVariation
from utils import logging


class InductiveNetConsistencyTrainer:
    def __init__(self, id, config):
        """
        Consistency trainer specialized for InductiveNet, which returns both a
        segmentation map and an image reconstruction.

        :param id: String identifying this run; used in log entries and checkpoint filenames.
        :param config: Dict of hyperparameters: device, lr, epochs, batch_size.
        """
        self.config = config
        self.device = config["device"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.epochs = config["epochs"]
        self.id = id
        self.model_str = "InductiveNet"
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.nakedcloss = NakedConsistencyLoss()
        self.closs = ConsistencyLoss()
        self.model = InductiveNet().to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mse = nn.MSELoss()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
        self.train_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="train", augment=False)
        self.val_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="val")
        self.test_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test")
        self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_set)
        self.test_loader = DataLoader(self.test_set)

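    # train_epoch minimizes the average of two terms: a consistency loss between
    # predictions on MNV-perturbed and unperturbed images, and the MSE between
    # the input image and InductiveNet's auxiliary reconstruction.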
    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            aug_img, aug_mask = self.mnv(image, mask)
            self.optimizer.zero_grad()
            aug_output, _ = self.model(aug_img)
            output, reconstruction = self.model(image)
            mean_iou = torch.mean(iou(output, mask))
            loss = 0.5 * (self.closs(aug_mask, mask, aug_output, output, mean_iou)
                          + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        best_val_loss = float("inf")
        print("Starting Segmentation training")
        best_consistency = 0
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_iou = np.mean(self.test().numpy())

            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config,
                             result_dict={"train_loss": training_loss, "val_loss": val_loss,
                                          "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                                          "ood_iou": gen_iou, "consistency": consistency},
                             type="consistency")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            # Three checkpoints are kept: best validation loss, maximum
            # consistency, and the weights from the last completed epoch.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(f"experiments/Data/Augmented-Pipelines/{self.model_str}/{self.id}", test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(), f"Predictors/Augmented/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Augmented/{self.model_str}/{self.id}")

            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/maximum_consistency{self.id}")
            torch.save(self.model.state_dict(), f"Predictors/Augmented/{self.model_str}/{self.id}_last_epoch")

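    # test() computes per-image IoU over the IID HyperKvasir test split.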
    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output, reconstruction = self.model(image)
                aug_output, _ = self.model(aug_img)  # todo consider train on augmented vs non-augmented?

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = 0.5 * (self.closs(aug_mask, mask, aug_output, output, torch.mean(batch_ious))
                              + self.mse(image, reconstruction))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))

                if plot:
                    # Draw the reconstruction first, then overlay the segmentation output.
                    plt.imshow(reconstruction[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow(output[0, 0].cpu().numpy(), alpha=0.5)
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

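    # Out-of-distribution performance is measured on the ETIS-LaribPolypDB dataset.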
    def validate_generalizability(self, epoch, plot=False):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, index in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")):
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
                if plot:
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch (hacky, but works)
        return ious


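# InductiveNetVanillaTrainer keeps the reconstruction term but swaps the
# consistency loss for a plain Jaccard segmentation loss; consistency is still
# measured during validation so the same metrics get logged.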
class InductiveNetVanillaTrainer(InductiveNetConsistencyTrainer):
    def __init__(self, id, config):
        super(InductiveNetVanillaTrainer, self).__init__(id, config)

    def train(self):
        best_val_loss = float("inf")
        print("Starting Segmentation training")
        best_consistency = 0
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_iou = np.mean(self.test().numpy())

            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config,
                             result_dict={"train_loss": training_loss, "val_loss": val_loss,
                                          "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                                          "ood_iou": gen_iou, "consistency": consistency},
                             type="consistency")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(f"experiments/Data/Normal-Pipelines/{self.model_str}/{self.id}", test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(), f"Predictors/Vanilla/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Vanilla/{self.model_str}/{self.id}")

            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/maximum_consistency{self.id}")
            torch.save(self.model.state_dict(), f"Predictors/Vanilla/{self.model_str}/{self.id}_last_epoch")

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output, reconstruction = self.model(image)
            loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)  # InductiveNet returns (segmentation, reconstruction)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output, reconstruction = self.model(image)
                aug_output, _ = self.model(aug_img)  # todo consider train on augmented vs non-augmented?

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))

                if plot:
                    # Draw the reconstruction first, then overlay the segmentation output.
                    plt.imshow(reconstruction[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow(output[0, 0].cpu().numpy(), alpha=0.5)
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

    # validate_generalizability is inherited unchanged from InductiveNetConsistencyTrainer.


class InductiveNetAugmentationTrainer(InductiveNetConsistencyTrainer):
    """
    Uses vanilla data augmentation with p=0.5 instead of a custom loss.
    """

    def __init__(self, id, config):
        super(InductiveNetAugmentationTrainer, self).__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mse = vanilla_losses.MSELoss()
        self.prob = 0  # note: initialized to 0; the p=0.5 from the docstring is not set in this file
        self.dataset = KvasirMNVset("Datasets/HyperKvasir", "train", inpaint=config["use_inpainter"])
        self.train_loader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)

    def get_iou_weights(self, image, mask):
        self.model.eval()
        with torch.no_grad():
            output, _ = self.model(image)
        return torch.mean(iou(output, mask))

    def get_consistency(self, image, mask, augmented, augmask):
        self.model.eval()
        with torch.no_grad():
            output, _ = self.model(image)
        self.model.train()
        # Note: this argument order differs from the (aug_mask, mask, aug_output, output)
        # convention used in validate(); preserved as in the original.
        return torch.mean(self.nakedcloss(output, mask, augmented, augmask))

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname, flag in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output, reconstruction = self.model(image)
            loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


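# InductiveNetEnsembleTrainer reuses the consistency training loop, but replaces
# the single InductiveNet with a pre-trained ensemble.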
class InductiveNetEnsembleTrainer(InductiveNetConsistencyTrainer):
    def __init__(self, id, config):
        super(InductiveNetEnsembleTrainer, self).__init__(id, config)
        self.model = TrainedEnsemble("Singular")
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
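

# --- Minimal usage sketch (illustrative; not part of the original file). ---
# The config keys below are the ones read by the constructors above; the
# values are placeholder assumptions, not the authors' settings.
if __name__ == "__main__":
    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "lr": 1e-4,
        "batch_size": 8,
        "epochs": 100,
        "use_inpainter": False,  # only read by InductiveNetAugmentationTrainer
    }
    trainer = InductiveNetConsistencyTrainer(id="inductivenet-demo", config=config)
    trainer.train()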