
training/consistency_trainers.py
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.hyperkvasir import KvasirSegmentationDataset, KvasirMNVset
from evaluation.metrics import iou
from losses import vanilla_losses  # JaccardLoss is used below; module path assumed from the losses package
from losses.consistency_losses import *
from models.ensembles import TrainedEnsemble
from perturbation.model import ModelOfNaturalVariation
from training.vanilla_trainer import VanillaTrainer
from utils import logging


class ConsistencyTrainer(VanillaTrainer):
    def __init__(self, id, config):
        super(ConsistencyTrainer, self).__init__(id, config)
        self.criterion = ConsistencyLoss().to(self.device)
        self.nakedcloss = NakedConsistencyLoss()

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            # Sample an augmented view of the batch from the model of natural variation
            aug_img, aug_mask = self.mnv(image, mask)
            self.optimizer.zero_grad()
            output = self.model(image)
            aug_output = self.model(aug_img)
            # The consistency criterion compares predictions on the original and
            # augmented views, weighted by the current segmentation quality
            mean_iou = torch.mean(iou(output, mask))
            loss = self.criterion(aug_mask, mask, aug_output, output, mean_iou)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        best_val_loss = float("inf")
        best_consistency = 0
        print("Starting Segmentation training")
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_iou = np.mean(self.test().numpy())

            self.config["lr"] = [group["lr"] for group in self.optimizer.param_groups]
            logging.log_full(
                epoch=i, id=self.id, config=self.config,
                result_dict={"train_loss": training_loss, "val_loss": val_loss,
                             "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                             "ood_iou": gen_iou, "consistency": consistency},
                type="consistency")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou} \t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou} \t"
                f" consistency={consistency}"
            )

            # Checkpoint on best validation loss, recording the IID test IoU alongside
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(
                    f"experiments/Data/Augmented-Pipelines/{self.model_str}/{self.id}",
                    test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Augmented/{self.model_str}/{self.id}")

            # Separately checkpoint the most consistent model seen so far
            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/maximum_consistency{self.id}")
        torch.save(self.model.state_dict(),
                   f"Predictors/Augmented/{self.model_str}/{self.id}_last_epoch")

    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                output = self.model(image)
                # Per-image IoU, accumulated over the whole test set
                batch_ious = torch.Tensor([iou(output_i, mask_i) for output_i, mask_i in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious


class ConsistencyTrainerUsingAugmentation(ConsistencyTrainer):
    """
        Uses vanilla data augmentation with p=0.5 instead of a custom loss.
    """

    def __init__(self, id, config):
        super(ConsistencyTrainerUsingAugmentation, self).__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.prob = 0
        self.dataset = KvasirMNVset("Datasets/HyperKvasir", "train", inpaint=config["use_inpainter"])
        self.train_loader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)

    def get_iou_weights(self, image, mask):
        self.model.eval()
        with torch.no_grad():
            output = self.model(image)
        return torch.mean(iou(output, mask))

    def get_consistency(self, image, mask, augmented, augmask):
        self.model.eval()
        with torch.no_grad():
            output = self.model(image)
            aug_output = self.model(augmented)  # evaluate the augmented image; matches the argument order used elsewhere
        self.model.train()
        return torch.mean(self.nakedcloss(augmask, mask, aug_output, output))

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname, flag in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.jaccard(output, mask)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)
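
# Note: despite the class name, no consistency term enters the gradient here;
# augmentation happens inside KvasirMNVset and train_epoch minimizes plain
# Jaccard loss on whatever samples the loader yields. get_iou_weights and
# get_consistency are gradient-free helpers, presumably used for logging.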


class AdversarialConsistencyTrainer(ConsistencyTrainer):
    """
        Adversarially samples difficult augmentations: several candidates are
        drawn per batch and training uses the one with the highest consistency loss.
    """

    def __init__(self, id, config):
        # Deliberately skips ConsistencyTrainer.__init__, initializing from VanillaTrainer
        super(ConsistencyTrainer, self).__init__(id, config)
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.num_adv_samples = 10
        self.naked_closs = NakedConsistencyLoss()
        self.criterion = ConsistencyLoss().to(self.device)

    def sample_adversarial(self, image, mask, output):
        self.model.eval()
        aug_img, aug_mask = None, None
        max_severity = float("-inf")
        with torch.no_grad():
            # Best-of-N search: keep the augmentation the model handles worst
            for i in range(self.num_adv_samples):
                adv_aug_img, adv_aug_mask = self.mnv(image, mask)
                adv_aug_output = self.model(adv_aug_img)
                severity = self.naked_closs(adv_aug_mask, mask, adv_aug_output, output)
                if severity > max_severity:
                    max_severity = severity
                    aug_img = adv_aug_img
                    aug_mask = adv_aug_mask
        self.model.train()
        return aug_img, aug_mask
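
    # Cost note: each batch pays num_adv_samples extra no-grad forward passes
    # on top of the two gradient-tracked passes in train_epoch, so an epoch
    # runs roughly (num_adv_samples + 2) / 2 times the forward passes of the
    # base ConsistencyTrainer (less in wall-clock, since no-grad is cheaper).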

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(image)
            # Train against the hardest of the sampled augmentations
            aug_img, aug_mask = self.sample_adversarial(image, mask, output)
            aug_output = self.model(aug_img)
            mean_iou = torch.mean(iou(output, mask))
            loss = self.criterion(aug_mask, mask, aug_output, output, mean_iou)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


class StrictConsistencyTrainer(ConsistencyTrainer):
    def __init__(self, id, config):
        super(StrictConsistencyTrainer, self).__init__(id, config)
        self.criterion = StrictConsistencyLoss().to(self.device)


class ConsistencyTrainerUsingControlledAugmentation(ConsistencyTrainer):
    """
        Uses vanilla data augmentation instead of a custom loss; each batch is
        trained alongside an augmented copy of itself, so half of every
        effective batch is augmented.
    """

    def __init__(self, id, config):
        super(ConsistencyTrainerUsingControlledAugmentation, self).__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            aug_img, aug_mask = self.mnv(image, mask)
            # Original and augmented samples are trained in one concatenated batch
            img_batch = torch.cat((image, aug_img))
            mask_batch = torch.cat((mask, aug_mask))
            self.optimizer.zero_grad()
            output = self.model(img_batch)
            loss = self.jaccard(output, mask_batch)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)
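
# Memory note: concatenating the clean and augmented views doubles the
# effective batch size per optimizer step, so this trainer needs roughly twice
# the activation memory of the others for the same config["batch_size"].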


class EnsembleConsistencyTrainer(ConsistencyTrainer):
    def __init__(self, id, config):
        super(EnsembleConsistencyTrainer, self).__init__(id, config)
        self.model = TrainedEnsemble("Singular")
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
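

# Minimal usage sketch. The full config schema is defined by VanillaTrainer;
# only "use_inpainter" and "batch_size" are read directly in this file, and
# "lr" is written back during training. The remaining keys below are
# illustrative assumptions, not the confirmed schema.
if __name__ == "__main__":
    config = {
        "batch_size": 8,          # read by ConsistencyTrainerUsingAugmentation above
        "use_inpainter": False,   # read by KvasirMNVset above
        "lr": 1e-4,               # assumed key backing self.lr
        "epochs": 200,            # assumed key backing self.epochs
        "model": "DeepLab",       # assumed key backing self.model_str
    }
    trainer = ConsistencyTrainer(id="consistency_demo", config=config)
    trainer.train()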