# training/vanilla_trainer.py
1
import matplotlib.pyplot as plt
2
import numpy as np
3
import segmentation_models_pytorch.utils.losses as vanilla_losses
4
import torch.optim.optimizer
5
from torch.utils.data import DataLoader
6
7
from data.etis import EtisDataset
8
from data.hyperkvasir import KvasirSegmentationDataset
9
from models import segmentation_models
10
from evaluation.metrics import iou
11
from losses.consistency_losses import NakedConsistencyLoss, ConsistencyLoss
12
from perturbation.model import ModelOfNaturalVariation
13
from utils import logging
14
15
16
class VanillaTrainer:
    """Standard single-model segmentation trainer.

    Trains one segmentation model on the HyperKvasir dataset with a Jaccard
    loss, tracks consistency metrics under a model of natural variation, and
    evaluates out-of-distribution generalizability on ETIS-LaribPolypDB.
    Checkpoints the best-validation-loss model, the maximum-consistency
    model, and the last-epoch model.
    """

    def __init__(self, id, config):
        """
        :param id: Identifier used in log entries and checkpoint filenames.
        :param config: Dict of hyperparameters: device, lr, batch_size,
                       epochs, and model (one of DeepLab, TriUnet, Unet,
                       FPN, InductiveNet).
        """
        self.config = config
        self.device = config["device"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.epochs = config["epochs"]
        self.model = None
        self.id = id
        self.model_str = config["model"]
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.nakedcloss = NakedConsistencyLoss()
        self.closs = ConsistencyLoss()

        # Dispatch table keeps the valid choices in one place instead of a
        # long if/elif chain.
        model_registry = {
            "DeepLab": segmentation_models.DeepLab,
            "TriUnet": segmentation_models.TriUnet,
            "Unet": segmentation_models.Unet,
            "FPN": segmentation_models.FPN,
            "InductiveNet": segmentation_models.InductiveNet,
        }
        if self.model_str not in model_registry:
            raise AttributeError("model_str not valid; choices are DeepLab, TriUnet, InductiveNet, FPN, Unet")
        self.model = model_registry[self.model_str]().to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.criterion = vanilla_losses.JaccardLoss()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
        self.train_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="train", augment=False)
        self.val_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="val", augment=False)
        self.test_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test", augment=False)
        self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_set)
        self.test_loader = DataLoader(self.test_set)

    def train_epoch(self):
        """Run one optimization pass over the training set.

        :return: Mean absolute Jaccard loss over all training batches.
        """
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            # Fix: use the configured device instead of hard-coded "cuda",
            # consistent with self.device set in __init__.
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, mask)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        """Full training loop.

        Per epoch: trains, validates on the IID split, evaluates OOD IoU on
        ETIS, logs all metrics, steps the LR scheduler, and checkpoints the
        best-val-loss and maximum-consistency models. Saves the final model
        after the last epoch.
        """
        print("Starting Segmentation training")
        # Infinite sentinels: the previous finite ones (10, 100) could skip
        # checkpointing entirely if the first epoch's losses exceeded them.
        best_val_loss = float("inf")
        best_closs = float("inf")

        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_ious = np.mean(self.test().numpy())
            # Record the scheduler's current LR(s) so the logged config
            # reflects this epoch, not the initial value.
            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config, result_dict={
                "train_loss": training_loss, "val_loss": val_loss,
                "iid_val_iou": mean_iou, "iid_test_iou": test_ious, "ood_iou": gen_iou,
                "consistency": consistency}, type="vanilla")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            if val_loss < best_val_loss:
                test_ious = self.test()
                best_val_loss = val_loss
                np.save(
                    f"experiments/Data/Normal-Pipelines/{self.model_str}/{self.id}",
                    test_ious)
                print(f"Saving new best model. IID test-set mean iou: {float(np.mean(test_ious.numpy()))}")
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Vanilla/{self.model_str}/{self.id}")
            if closs < best_closs:
                best_closs = closs
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/{self.id}-maximum-consistency")
        torch.save(self.model.state_dict(),
                   f"Predictors/Vanilla/{self.model_str}/{self.id}_last_epoch")

    def test(self):
        """Evaluate the model on the IID test split.

        :return: 1-D CPU tensor of per-image IoU scores.
        """
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)  # fix: configured device, not "cuda"
                mask = y.to(self.device)
                output = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        """Validate on the IID validation split with consistency metrics.

        :param epoch: Current epoch index (used only in the plot title).
        :param plot: If True, show one qualitative example for this epoch.
        :return: (mean consistency loss, per-image IoU tensor,
                  mean naked consistency loss)
        """
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                image = x.to(self.device)  # fix: configured device, not "cuda"
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output = self.model(image)
                aug_output = self.model(aug_img)  # todo consider train on augmented vs non-augmented?

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = self.closs(aug_mask, mask, aug_output, output, torch.mean(batch_ious))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))
                if plot:
                    plt.imshow(y[0, 0].cpu().numpy(), alpha=0.5)
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.imshow(y[0, 0].cpu().numpy().astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

    def validate_generalizability(self, epoch, plot=False):
        """Evaluate out-of-distribution IoU on ETIS-LaribPolypDB.

        :param epoch: Current epoch index (used only in the plot title).
        :param plot: If True, show one qualitative example for this epoch.
        :return: 1-D CPU tensor of per-image IoU scores on the OOD set.
        """
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, index in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")):
                image = x.to(self.device)  # fix: configured device, not "cuda"
                mask = y.to(self.device)
                output = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
                if plot:
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch (hacky, but works)
            return ious