training/consistency_trainers.py
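
"""
Trainers that regularize segmentation models to be consistent under a model
of natural variation (MNV): the base ConsistencyTrainer, plus variants using
plain augmentation, controlled augmentation, adversarially sampled
augmentations, a stricter consistency loss, and a trained ensemble.
"""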
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.hyperkvasir import KvasirSegmentationDataset, KvasirMNVset
from evaluation.metrics import iou
from losses.consistency_losses import *
from losses import vanilla_losses  # provides JaccardLoss; exact module path assumed
from perturbation.model import ModelOfNaturalVariation
from training.vanilla_trainer import VanillaTrainer
from utils import logging
from models.ensembles import TrainedEnsemble


class ConsistencyTrainer(VanillaTrainer):
    """Trains with a consistency loss between predictions on original and
    MNV-perturbed inputs; the criterion also receives the batch mean IoU."""

    def __init__(self, id, config):
        super(ConsistencyTrainer, self).__init__(id, config)
        self.criterion = ConsistencyLoss().to(self.device)
        self.nakedcloss = NakedConsistencyLoss()

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to("cuda")
            mask = y.to("cuda")
            # Sample a perturbed (image, mask) pair from the model of natural variation
            aug_img, aug_mask = self.mnv(image, mask)
            self.optimizer.zero_grad()
            output = self.model(image)
            aug_output = self.model(aug_img)
            mean_iou = torch.mean(iou(output, mask))
            loss = self.criterion(aug_mask, mask, aug_output, output, mean_iou)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)

    def train(self):
        best_val_loss = 1000
        best_consistency = 0
        print("Starting Segmentation training")
        for i in range(self.epochs):
            training_loss = np.abs(self.train_epoch())
            val_loss, ious, closs = self.validate(epoch=i, plot=False)
            gen_ious = self.validate_generalizability(epoch=i, plot=False)
            mean_iou = float(torch.mean(ious))
            gen_iou = float(torch.mean(gen_ious))
            consistency = 1 - np.mean(closs)  # low consistency loss -> high consistency
            test_iou = np.mean(self.test().numpy())

            # Record the current learning rate(s) so they are logged with the results
            self.config["lr"] = [group["lr"] for group in self.optimizer.param_groups]
            logging.log_full(epoch=i, id=self.id, config=self.config,
                             result_dict={"train_loss": training_loss, "val_loss": val_loss,
                                          "iid_val_iou": mean_iou, "iid_test_iou": test_iou,
                                          "ood_iou": gen_iou, "consistency": consistency},
                             type="consistency")

            self.scheduler.step(i)
            # self.mnv.step()
            print(
                f"Epoch {i} of {self.epochs} \t"
                f" lr={[group['lr'] for group in self.optimizer.param_groups]} \t"
                f" loss={training_loss} \t"
                f" val_loss={val_loss} \t"
                f" ood_iou={gen_iou} \t"
                f" val_iou={mean_iou} \t"
                f" gen_prop={gen_iou / mean_iou} \t"
                f" consistency={consistency}"
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                np.save(f"experiments/Data/Augmented-Pipelines/{self.model_str}/{self.id}",
                        test_iou)
                print(f"Saving new best model. IID test-set mean iou: {test_iou}")
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/{self.id}")
                print("saved in: ", f"Predictors/Augmented/{self.model_str}/{self.id}")

            if consistency > best_consistency:
                best_consistency = consistency
                torch.save(self.model.state_dict(),
                           f"Predictors/Augmented/{self.model_str}/maximum_consistency{self.id}")
            # Always keep a checkpoint from the most recent epoch
            torch.save(self.model.state_dict(),
                       f"Predictors/Augmented/{self.model_str}/{self.id}_last_epoch")

    def test(self):
        self.model.eval()
        ious = torch.empty((0,))
        with torch.no_grad():
            for x, y, fname in self.test_loader:
                image = x.to("cuda")
                mask = y.to("cuda")
                output = self.model(image)
                # Per-sample IoUs, so callers can aggregate however they need
                batch_ious = torch.Tensor([iou(output_i, mask_j) for output_i, mask_j in zip(output, mask)])
                ious = torch.cat((ious, batch_ious.flatten()))
        return ious


class ConsistencyTrainerUsingAugmentation(ConsistencyTrainer):
    """
    Uses vanilla data augmentation with p=0.5 instead of a custom loss.
    """

    def __init__(self, id, config):
        super(ConsistencyTrainerUsingAugmentation, self).__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.prob = 0
        self.dataset = KvasirMNVset("Datasets/HyperKvasir", "train", inpaint=config["use_inpainter"])
        self.train_loader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)

    def get_iou_weights(self, image, mask):
        self.model.eval()
        with torch.no_grad():
            output = self.model(image)
            return torch.mean(iou(output, mask))

    def get_consistency(self, image, mask, augmented, augmask):
        self.model.eval()
        with torch.no_grad():
            output = self.model(image)
        self.model.train()
        return torch.mean(self.nakedcloss(output, mask, augmented, augmask))

    def train_epoch(self):
        self.model.train()
        losses = []
        # KvasirMNVset also yields a per-sample flag (unused here)
        for x, y, fname, flag in self.train_loader:
            image = x.to("cuda")
            mask = y.to("cuda")
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.jaccard(output, mask)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


class AdversarialConsistencyTrainer(ConsistencyTrainer):
    """
    Adversarially samples difficult augmentations: each batch is trained on
    the MNV augmentation that induces the highest consistency loss.
    """

    def __init__(self, id, config):
        super(AdversarialConsistencyTrainer, self).__init__(id, config)
        self.mnv = ModelOfNaturalVariation(T0=1).to(self.device)
        self.num_adv_samples = 10
        self.naked_closs = NakedConsistencyLoss()
        self.criterion = ConsistencyLoss().to(self.device)

    def sample_adversarial(self, image, mask, output):
        self.model.eval()
        aug_img, aug_mask = None, None  # most severe augmentation found so far
        max_severity = -10
        with torch.no_grad():
            for i in range(self.num_adv_samples):
                adv_aug_img, adv_aug_mask = self.mnv(image, mask)
                adv_aug_output = self.model(adv_aug_img)
                severity = self.naked_closs(adv_aug_mask, mask, adv_aug_output, output)
                if severity > max_severity:
                    max_severity = severity
                    aug_img = adv_aug_img
                    aug_mask = adv_aug_mask
        self.model.train()
        return aug_img, aug_mask

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to("cuda")
            mask = y.to("cuda")
            self.optimizer.zero_grad()
            output = self.model(image)
            # Pick the most loss-inducing of num_adv_samples candidate augmentations
            aug_img, aug_mask = self.sample_adversarial(image, mask, output)
            aug_output = self.model(aug_img)
            mean_iou = torch.mean(iou(output, mask))
            loss = self.criterion(aug_mask, mask, aug_output, output, mean_iou)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


class StrictConsistencyTrainer(ConsistencyTrainer):
    def __init__(self, id, config):
        super(StrictConsistencyTrainer, self).__init__(id, config)
        self.criterion = StrictConsistencyLoss()


class ConsistencyTrainerUsingControlledAugmentation(ConsistencyTrainer):
    """
    Uses vanilla data augmentation with p=0.5 instead of a custom loss; each
    batch contains two samples per image, the original and its MNV-augmented copy.
    """

    def __init__(self, id, config):
        super(ConsistencyTrainerUsingControlledAugmentation, self).__init__(id, config)
        self.jaccard = vanilla_losses.JaccardLoss()
        self.mnv = ModelOfNaturalVariation(1)

    def train_epoch(self):
        self.model.train()
        losses = []
        for x, y, fname in self.train_loader:
            image = x.to("cuda")
            mask = y.to("cuda")
            aug_img, aug_mask = self.mnv(image, mask)
            # Train on originals and augmented copies in one doubled batch
            img_batch = torch.cat((image, aug_img))
            mask_batch = torch.cat((mask, aug_mask))
            self.optimizer.zero_grad()
            output = self.model(img_batch)
            loss = self.jaccard(output, mask_batch)
            loss.backward()
            self.optimizer.step()
            losses.append(np.abs(loss.item()))
        return np.mean(losses)


class EnsembleConsistencyTrainer(ConsistencyTrainer):
    def __init__(self, id, config):
        super(EnsembleConsistencyTrainer, self).__init__(id, config)
        self.model = TrainedEnsemble("Singular")
        # The ensemble replaces the single model, so rebuild the optimizer and
        # scheduler over the ensemble's parameters
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
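

# Usage sketch (illustrative only): the exact config keys are defined by
# VanillaTrainer and this repo's experiment configs; the values below are
# hypothetical placeholders.
#
#     config = {"batch_size": 8, "epochs": 250, "use_inpainter": False}
#     trainer = ConsistencyTrainer(id="consistency_0", config=config)
#     trainer.train()
#     test_ious = trainer.test()
#     print(f"mean IID test IoU: {float(test_ious.mean()):.3f}")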