|
a |
|
b/training/vanilla_trainer.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import numpy as np |
|
|
3 |
import segmentation_models_pytorch.utils.losses as vanilla_losses |
|
|
4 |
import torch.optim.optimizer |
|
|
5 |
from torch.utils.data import DataLoader |
|
|
6 |
|
|
|
7 |
from data.etis import EtisDataset |
|
|
8 |
from data.hyperkvasir import KvasirSegmentationDataset |
|
|
9 |
from models import segmentation_models |
|
|
10 |
from evaluation.metrics import iou |
|
|
11 |
from losses.consistency_losses import NakedConsistencyLoss, ConsistencyLoss |
|
|
12 |
from perturbation.model import ModelOfNaturalVariation |
|
|
13 |
from utils import logging |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class VanillaTrainer:
    """Trains a vanilla (no augmentation) polyp-segmentation model on HyperKvasir.

    Tracks IID validation/test IoU, out-of-distribution IoU on ETIS-Larib, and
    consistency under the model of natural variation (MNV). Saves checkpoints for
    the best validation loss, the best consistency, and the last epoch.
    """

    def __init__(self, id, config):
        """
        :param id: Identifier used in log entries and checkpoint filenames.
        :param config: Dict of hyperparameters: "model" (DeepLab, TriUnet,
            InductiveNet, FPN or Unet), "device", "lr", "batch_size", "epochs".
        """
        self.config = config
        self.device = config["device"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.epochs = config["epochs"]
        self.model = None
        self.id = id
        self.model_str = config["model"]
        # The perturbation model and consistency losses are only used to compute
        # validation-time consistency metrics; they do not affect training.
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.nakedcloss = NakedConsistencyLoss()
        self.closs = ConsistencyLoss()

        # Dispatch table instead of an if/elif chain; same error message as before
        # when an unknown model string is supplied.
        model_factories = {
            "DeepLab": segmentation_models.DeepLab,
            "TriUnet": segmentation_models.TriUnet,
            "Unet": segmentation_models.Unet,
            "FPN": segmentation_models.FPN,
            "InductiveNet": segmentation_models.InductiveNet,
        }
        if self.model_str not in model_factories:
            raise AttributeError("model_str not valid; choices are DeepLab, TriUnet, InductiveNet, FPN, Unet")
        self.model = model_factories[self.model_str]().to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.criterion = vanilla_losses.JaccardLoss()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
        self.train_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="train", augment=False)
        self.val_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="val", augment=False)
        self.test_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test", augment=False)
        self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
        # Val/test loaders use the DataLoader default batch size of 1.
        self.val_loader = DataLoader(self.val_set)
        self.test_loader = DataLoader(self.test_set)

    def train_epoch(self):
        """Run one optimization pass over the training set.

        :return: Mean absolute Jaccard loss over the epoch's batches.
        """
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            # BUGFIX: was hard-coded .to("cuda"); honor config["device"] so the
            # trainer works on whatever device the model was moved to.
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, mask)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        """Full training loop with per-epoch validation, OOD evaluation and logging.

        Checkpoints: best-val-loss model, maximum-consistency model, and the
        last-epoch model (saved every epoch).
        """
        best_val_loss = 10
        print("Starting Segmentation training")
        best_closs = 100

        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            # Consistency is reported as 1 - mean naked-consistency loss.
            consistency = 1 - np.mean(closs)
            test_ious = np.mean(self.test().numpy())
            # Record the current (scheduled) learning rate(s) in the config for logging.
            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config, result_dict=
            {"train_loss": training_loss, "val_loss": val_loss,
             "iid_val_iou": mean_iou, "iid_test_iou": test_ious, "ood_iou": gen_iou,
             "consistency": consistency}, type="vanilla")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            if val_loss < best_val_loss:
                # New best validation loss: persist test IoUs and model weights.
                test_ious = self.test()
                best_val_loss = val_loss
                np.save(
                    f"experiments/Data/Normal-Pipelines/{self.model_str}/{self.id}",
                    test_ious)
                print(f"Saving new best model. IID test-set mean iou: {float(np.mean(test_ious.numpy()))}")
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Vanilla/{self.model_str}/{self.id}")
            if closs < best_closs:
                best_closs = closs
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/{self.id}-maximum-consistency")
            # Always keep a rolling last-epoch checkpoint.
            torch.save(self.model.state_dict(),
                       f"Predictors/Vanilla/{self.model_str}/{self.id}_last_epoch")

    def test(self):
        """Evaluate on the IID test set.

        :return: 1-D CPU tensor of per-image IoUs.
        """
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                # BUGFIX: was hard-coded .to("cuda"); use the configured device.
                image = x.to(self.device)
                mask = y.to(self.device)
                output = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        """Evaluate on the IID validation set, including consistency under MNV.

        :param epoch: Current epoch (only used in the optional plot title).
        :param plot: If True, show one example prediction for this epoch.
        :return: (mean consistency-loss, per-image IoU tensor, mean naked-consistency loss).
        """
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                # BUGFIX: was hard-coded .to("cuda"); use the configured device.
                image = x.to(self.device)
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output = self.model(image)
                aug_output = self.model(aug_img)  # todo consider train on augmented vs non-augmented?

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = self.closs(aug_mask, mask, aug_output, output, torch.mean(batch_ious))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))
                if plot:
                    plt.imshow(y[0, 0].cpu().numpy(), alpha=0.5)
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.imshow(y[0, 0].cpu().numpy().astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

    def validate_generalizability(self, epoch, plot=False):
        """Evaluate out-of-distribution generalizability on ETIS-LaribPolypDB.

        :param epoch: Current epoch (only used in the optional plot title).
        :param plot: If True, show one example prediction for this epoch.
        :return: 1-D CPU tensor of per-image IoUs on the OOD set.
        """
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, index in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")):
                # BUGFIX: was hard-coded .to("cuda"); use the configured device.
                image = x.to(self.device)
                mask = y.to(self.device)
                output = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
                if plot:
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch (hacky, but works)
        return ious