# training/inductivenet_trainers.py
|
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data.etis import EtisDataset
from data.hyperkvasir import KvasirSegmentationDataset, KvasirMNVset
from evaluation.metrics import iou
from losses.consistency_losses import *
from losses import vanilla_losses  # assumed module path; JaccardLoss and MSELoss are referenced below
from models.ensembles import TrainedEnsemble
from models.segmentation_models import InductiveNet
from perturbation.model import ModelOfNaturalVariation
from utils import logging

class InductiveNetConsistencyTrainer:
    def __init__(self, id, config):
        """
        Consistency trainer for InductiveNet, which predicts both a segmentation
        mask and a reconstruction of the input image.

        :param id: identifier used in log entries and checkpoint filenames
        :param config: dict of hyperparameters: device, lr, epochs, batch_size
        """
        self.config = config
        self.device = config["device"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.epochs = config["epochs"]
        self.id = id
        self.model_str = "InductiveNet"
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.nakedcloss = NakedConsistencyLoss()
        self.closs = ConsistencyLoss()
        self.model = InductiveNet().to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mse = nn.MSELoss()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
        self.train_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="train", augment=False)
        self.val_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="val")
        self.test_set = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test")
        self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_set)
        self.test_loader = DataLoader(self.test_set)

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)  # use the configured device rather than hard-coded "cuda"
            mask = y.to(self.device)
            aug_img, aug_mask = self.mnv(image, mask)
            self.optimizer.zero_grad()
            aug_output, _ = self.model(aug_img)
            output, reconstruction = self.model(image)
            mean_iou = torch.mean(iou(output, mask))
            # equal-weighted sum of the consistency loss and the reconstruction loss
            loss = 0.5 * (self.closs(aug_mask, mask, aug_output, output, mean_iou) + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        best_val_loss = 10  # sentinel; any plausible validation loss is lower
        best_consistency = 0
        print("Starting Segmentation training")
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_iou = np.mean(self.test().numpy())

            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config,
                             result_dict={"train_loss": training_loss, "val_loss": val_loss,
                                          "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                                          "ood_iou": gen_iou, "consistency": consistency},
                             type="consistency")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(f"experiments/Data/Augmented-Pipelines/{self.model_str}/{self.id}", test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(), f"Predictors/Augmented/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Augmented/{self.model_str}/{self.id}")

            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/maximum_consistency{self.id}")
            # checkpoint the most recent weights every epoch
            torch.save(self.model.state_dict(), f"Predictors/Augmented/{self.model_str}/{self.id}_last_epoch")

    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output, reconstruction = self.model(image)
                aug_output, _ = self.model(aug_img)  # TODO: consider training on augmented vs non-augmented inputs

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = 0.5 * (self.closs(aug_mask, mask, aug_output, output, torch.mean(batch_ious))
                              + self.mse(image, reconstruction))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))

                if plot:
                    # draw the reconstruction first, then overlay the predicted mask
                    plt.imshow(reconstruction[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow(output[0, 0].cpu().numpy(), alpha=0.5)
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

    def validate_generalizability(self, epoch, plot=False):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, index in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")):
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
                if plot:
                    plt.imshow(image[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow((output[0, 0].cpu().numpy() > 0.5).astype(int), alpha=0.5)
                    plt.title("IoU {} at epoch {}".format(iou(output[0, 0], mask[0, 0]), epoch))
                    plt.show()
                    plot = False  # plot one example per epoch (hacky, but works)
        return ious


class InductiveNetVanillaTrainer(InductiveNetConsistencyTrainer):

    def __init__(self, id, config):
        super().__init__(id, config)

    def train(self):
        # identical to the parent's train() except for the Normal-Pipelines / Vanilla checkpoint paths
        best_val_loss = 10  # sentinel; any plausible validation loss is lower
        best_consistency = 0
        print("Starting Segmentation training")
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)
            test_iou = np.mean(self.test().numpy())

            self.config["lr"] = [group['lr'] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config,
                             result_dict={"train_loss": training_loss, "val_loss": val_loss,
                                          "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                                          "ood_iou": gen_iou, "consistency": consistency},
                             type="consistency")

            self.scheduler.step(i)
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou}\t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou}"
            )
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(f"experiments/Data/Normal-Pipelines/{self.model_str}/{self.id}", test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(), f"Predictors/Vanilla/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Vanilla/{self.model_str}/{self.id}")

            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Vanilla/{self.model_str}/maximum_consistency{self.id}")
            # checkpoint the most recent weights every epoch
            torch.save(self.model.state_dict(), f"Predictors/Vanilla/{self.model_str}/{self.id}_last_epoch")

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output, reconstruction = self.model(image)
            # plain supervised objective: Jaccard loss plus reconstruction MSE
            loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                output, _ = self.model(image)  # discard the reconstruction head's output
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious

    def validate(self, epoch, plot=False):
        self.model.eval()
        losses = []
        closses = []
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.val_loader:
                image = x.to(self.device)
                mask = y.to(self.device)
                aug_img, aug_mask = self.mnv(image, mask)
                output, reconstruction = self.model(image)
                aug_output, _ = self.model(aug_img)  # TODO: consider training on augmented vs non-augmented inputs

                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
                losses.append(np.abs(loss.item()))
                closses.append(self.nakedcloss(aug_mask, mask, aug_output, output).item())
                ious = torch.cat((ious, batch_ious.cpu().flatten()))

                if plot:
                    # draw the reconstruction first, then overlay the predicted mask
                    plt.imshow(reconstruction[0].permute(1, 2, 0).cpu().numpy())
                    plt.imshow(output[0, 0].cpu().numpy(), alpha=0.5)
                    plt.show()
                    plot = False  # plot one example per epoch
        avg_val_loss = np.mean(losses)
        avg_closs = np.mean(closses)
        return avg_val_loss, ious, avg_closs

    # validate_generalizability is inherited unchanged from InductiveNetConsistencyTrainer


class InductiveNetAugmentationTrainer(InductiveNetConsistencyTrainer):
    """
    Uses vanilla data augmentation with p=0.5 instead of a custom loss.
    """

    def __init__(self, id, config):
        super().__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mse = vanilla_losses.MSELoss()
        self.prob = 0
        self.dataset = KvasirMNVset("Datasets/HyperKvasir", "train", inpaint=config["use_inpainter"])
        self.train_loader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)

    def get_iou_weights(self, image, mask):
        self.model.eval()
        with torch.no_grad():
            output, _ = self.model(image)
            return torch.mean(iou(output, mask))

    def get_consistency(self, image, mask, augmented, augmask):
        self.model.eval()
        with torch.no_grad():
            output, _ = self.model(image)
        self.model.train()
        # argument order follows the (aug_mask, mask, aug_output, output) convention used elsewhere
        return torch.mean(self.nakedcloss(augmask, mask, augmented, output))

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname, flag in self.train_loader:
            image = x.to(self.device)
            mask = y.to(self.device)
            self.optimizer.zero_grad()
            output, reconstruction = self.model(image)
            loss = 0.5 * (self.jaccard(output, mask) + self.mse(image, reconstruction))
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


class InductiveNetEnsembleTrainer(InductiveNetConsistencyTrainer):
    def __init__(self, id, config):
        super().__init__(id, config)
        self.model = TrainedEnsemble("Singular")
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
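

# Minimal usage sketch (not part of the original module): the config keys mirror
# those read in __init__ above; the numeric values are illustrative assumptions,
# and "use_inpainter" is only consumed by InductiveNetAugmentationTrainer.
if __name__ == "__main__":
    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "lr": 1e-4,        # assumed value for illustration
        "batch_size": 8,   # assumed value for illustration
        "epochs": 250,     # assumed value for illustration
        "use_inpainter": False,
    }
    trainer = InductiveNetConsistencyTrainer(id="demo-run", config=config)
    trainer.train()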