|
a |
|
b/rocaseg/train_uda1.py |
|
|
1 |
import os |
|
|
2 |
import logging |
|
|
3 |
from collections import defaultdict |
|
|
4 |
import gc |
|
|
5 |
import click |
|
|
6 |
import resource |
|
|
7 |
|
|
|
8 |
import numpy as np |
|
|
9 |
import cv2 |
|
|
10 |
|
|
|
11 |
import torch |
|
|
12 |
import torch.nn.functional as torch_fn |
|
|
13 |
from torch import nn |
|
|
14 |
from torch.utils.data.dataloader import DataLoader |
|
|
15 |
from torch.utils.tensorboard import SummaryWriter |
|
|
16 |
from tqdm import tqdm |
|
|
17 |
|
|
|
18 |
from rocaseg.datasets import (DatasetOAIiMoSagittal2d, |
|
|
19 |
DatasetOKOASagittal2d, |
|
|
20 |
DatasetMAKNEESagittal2d, |
|
|
21 |
sources_from_path) |
|
|
22 |
from rocaseg.models import dict_models |
|
|
23 |
from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm, |
|
|
24 |
dict_optimizers, CheckpointHandler) |
|
|
25 |
from rocaseg.preproc import * |
|
|
26 |
from rocaseg.repro import set_ultimate_seed |
|
|
27 |
from rocaseg.components.mixup import mixup_criterion, mixup_data |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) |
|
|
31 |
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) |
|
|
32 |
|
|
|
33 |
cv2.ocl.setUseOpenCL(False) |
|
|
34 |
cv2.setNumThreads(0) |
|
|
35 |
|
|
|
36 |
logging.basicConfig() |
|
|
37 |
logger = logging.getLogger('train') |
|
|
38 |
logger.setLevel(logging.DEBUG) |
|
|
39 |
|
|
|
40 |
set_ultimate_seed() |
|
|
41 |
|
|
|
42 |
if torch.cuda.is_available(): |
|
|
43 |
maybe_gpu = 'cuda' |
|
|
44 |
else: |
|
|
45 |
maybe_gpu = 'cpu' |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
class ModelTrainer: |
|
|
49 |
def __init__(self, config, fold_idx=None): |
|
|
50 |
self.config = config |
|
|
51 |
self.fold_idx = fold_idx |
|
|
52 |
|
|
|
53 |
self.paths_weights_fold = dict() |
|
|
54 |
self.paths_weights_fold['segm'] = \ |
|
|
55 |
os.path.join(config['path_weights'], 'segm', f'fold_{self.fold_idx}') |
|
|
56 |
os.makedirs(self.paths_weights_fold['segm'], exist_ok=True) |
|
|
57 |
self.paths_weights_fold['discr'] = \ |
|
|
58 |
os.path.join(config['path_weights'], 'discr', f'fold_{self.fold_idx}') |
|
|
59 |
os.makedirs(self.paths_weights_fold['discr'], exist_ok=True) |
|
|
60 |
|
|
|
61 |
self.path_logs_fold = \ |
|
|
62 |
os.path.join(config['path_logs'], f'fold_{self.fold_idx}') |
|
|
63 |
os.makedirs(self.path_logs_fold, exist_ok=True) |
|
|
64 |
|
|
|
65 |
self.handlers_ckpt = dict() |
|
|
66 |
self.handlers_ckpt['segm'] = CheckpointHandler(self.paths_weights_fold['segm']) |
|
|
67 |
self.handlers_ckpt['discr'] = CheckpointHandler(self.paths_weights_fold['discr']) |
|
|
68 |
|
|
|
69 |
paths_ckpt_sel = dict() |
|
|
70 |
paths_ckpt_sel['segm'] = self.handlers_ckpt['segm'].get_last_ckpt() |
|
|
71 |
paths_ckpt_sel['discr'] = self.handlers_ckpt['discr'].get_last_ckpt() |
|
|
72 |
|
|
|
73 |
# Initialize and configure the models |
|
|
74 |
self.models = dict() |
|
|
75 |
self.models['segm'] = (dict_models[config['model_segm']] |
|
|
76 |
(input_channels=self.config['input_channels'], |
|
|
77 |
output_channels=self.config['output_channels'], |
|
|
78 |
center_depth=self.config['center_depth'], |
|
|
79 |
pretrained=self.config['pretrained'], |
|
|
80 |
path_pretrained=self.config['path_pretrained_segm'], |
|
|
81 |
restore_weights=self.config['restore_weights'], |
|
|
82 |
path_weights=paths_ckpt_sel['segm'])) |
|
|
83 |
self.models['segm'] = nn.DataParallel(self.models['segm']) |
|
|
84 |
self.models['segm'] = self.models['segm'].to(maybe_gpu) |
|
|
85 |
|
|
|
86 |
self.models['discr'] = (dict_models[config['model_discr']] |
|
|
87 |
(input_channels=self.config['output_channels'], |
|
|
88 |
output_channels=1, |
|
|
89 |
pretrained=self.config['pretrained'], |
|
|
90 |
restore_weights=self.config['restore_weights'], |
|
|
91 |
path_weights=paths_ckpt_sel['discr'])) |
|
|
92 |
self.models['discr'] = nn.DataParallel(self.models['discr']) |
|
|
93 |
self.models['discr'] = self.models['discr'].to(maybe_gpu) |
|
|
94 |
|
|
|
95 |
# Configure the training |
|
|
96 |
self.optimizers = dict() |
|
|
97 |
self.optimizers['segm'] = (dict_optimizers['adam']( |
|
|
98 |
self.models['segm'].parameters(), |
|
|
99 |
lr=self.config['lr_segm'], |
|
|
100 |
weight_decay=self.config['wd_segm'])) |
|
|
101 |
self.optimizers['discr'] = (dict_optimizers['adam']( |
|
|
102 |
self.models['discr'].parameters(), |
|
|
103 |
lr=self.config['lr_discr'], |
|
|
104 |
weight_decay=self.config['wd_discr'])) |
|
|
105 |
|
|
|
106 |
self.lr_update_rule = {25: 0.1} |
|
|
107 |
|
|
|
108 |
self.losses = dict() |
|
|
109 |
self.losses['segm'] = dict_losses[self.config['loss_segm']]( |
|
|
110 |
num_classes=self.config['output_channels'], |
|
|
111 |
) |
|
|
112 |
self.losses['advers'] = dict_losses['bce_loss']() |
|
|
113 |
self.losses['discr'] = dict_losses['bce_loss']() |
|
|
114 |
|
|
|
115 |
self.losses['segm'] = self.losses['segm'].to(maybe_gpu) |
|
|
116 |
self.losses['advers'] = self.losses['advers'].to(maybe_gpu) |
|
|
117 |
self.losses['discr'] = self.losses['discr'].to(maybe_gpu) |
|
|
118 |
|
|
|
119 |
self.tensorboard = SummaryWriter(self.path_logs_fold) |
|
|
120 |
|
|
|
121 |
def run_one_epoch(self, epoch_idx, loaders): |
|
|
122 |
COEFF_DISCR = 1 |
|
|
123 |
COEFF_SEGM = 1 |
|
|
124 |
COEFF_ADVERS = 0.001 |
|
|
125 |
|
|
|
126 |
fnames_acc = defaultdict(list) |
|
|
127 |
metrics_acc = dict() |
|
|
128 |
metrics_acc['samplew'] = defaultdict(list) |
|
|
129 |
metrics_acc['batchw'] = defaultdict(list) |
|
|
130 |
metrics_acc['datasetw'] = defaultdict(list) |
|
|
131 |
metrics_acc['datasetw']['cm_oai'] = \ |
|
|
132 |
np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32) |
|
|
133 |
metrics_acc['datasetw']['cm_okoa'] = \ |
|
|
134 |
np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32) |
|
|
135 |
|
|
|
136 |
prog_bar_params = {'postfix': {'epoch': epoch_idx}, } |
|
|
137 |
|
|
|
138 |
if self.models['segm'].training and self.models['discr'].training: |
|
|
139 |
# ------------------------ Training regime ------------------------ |
|
|
140 |
loader_oai = loaders['oai_imo']['train'] |
|
|
141 |
loader_maknee = loaders['maknee']['train'] |
|
|
142 |
|
|
|
143 |
steps_oai, steps_maknee = len(loader_oai), len(loader_maknee) |
|
|
144 |
steps_total = steps_oai |
|
|
145 |
prog_bar_params.update({'total': steps_total, |
|
|
146 |
'desc': f'Train, epoch {epoch_idx}'}) |
|
|
147 |
|
|
|
148 |
loader_oai_iter = iter(loader_oai) |
|
|
149 |
loader_maknee_iter = iter(loader_maknee) |
|
|
150 |
|
|
|
151 |
loader_oai_iter_old = None |
|
|
152 |
loader_maknee_iter_old = None |
|
|
153 |
|
|
|
154 |
with tqdm(**prog_bar_params) as prog_bar: |
|
|
155 |
for step_idx in range(steps_total): |
|
|
156 |
self.optimizers['segm'].zero_grad() |
|
|
157 |
self.optimizers['discr'].zero_grad() |
|
|
158 |
|
|
|
159 |
metrics_acc['batchw']['loss_total'].append(0) |
|
|
160 |
|
|
|
161 |
try: |
|
|
162 |
data_batch_oai = next(loader_oai_iter) |
|
|
163 |
except StopIteration: |
|
|
164 |
loader_oai_iter_old = loader_oai_iter |
|
|
165 |
loader_oai_iter = iter(loader_oai) |
|
|
166 |
data_batch_oai = next(loader_oai_iter) |
|
|
167 |
|
|
|
168 |
try: |
|
|
169 |
data_batch_maknee = next(loader_maknee_iter) |
|
|
170 |
except StopIteration: |
|
|
171 |
loader_maknee_iter_old = loader_maknee_iter |
|
|
172 |
loader_maknee_iter = iter(loader_maknee) |
|
|
173 |
data_batch_maknee = next(loader_maknee_iter) |
|
|
174 |
|
|
|
175 |
xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys'] |
|
|
176 |
fnames_acc['oai'].extend(data_batch_oai['path_image']) |
|
|
177 |
xs_oai = xs_oai.to(maybe_gpu) |
|
|
178 |
ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1) |
|
|
179 |
|
|
|
180 |
xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys'] |
|
|
181 |
fnames_acc['maknee'].extend(data_batch_maknee['path_image']) |
|
|
182 |
xs_maknee = xs_maknee.to(maybe_gpu) |
|
|
183 |
|
|
|
184 |
# -------------- Train discriminator network ------------- |
|
|
185 |
# With source |
|
|
186 |
ys_pred_oai = self.models['segm'](xs_oai) |
|
|
187 |
ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1) |
|
|
188 |
|
|
|
189 |
zs_pred_oai = self.models['discr'](ys_pred_softmax_oai) |
|
|
190 |
|
|
|
191 |
# Use 0 as a label for the source domain |
|
|
192 |
loss_discr_0 = self.losses['discr']( |
|
|
193 |
input=zs_pred_oai, |
|
|
194 |
target=torch.zeros_like(zs_pred_oai, device=maybe_gpu)) |
|
|
195 |
loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR |
|
|
196 |
loss_discr_0.backward(retain_graph=True) |
|
|
197 |
metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item()) |
|
|
198 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
199 |
metrics_acc['batchw']['loss_discr_0'][-1] |
|
|
200 |
|
|
|
201 |
# With target |
|
|
202 |
self.models['segm'] = self.models['segm'].eval() |
|
|
203 |
ys_pred_maknee = self.models['segm'](xs_maknee) |
|
|
204 |
self.models['segm'] = self.models['segm'].train() |
|
|
205 |
|
|
|
206 |
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1) |
|
|
207 |
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee) |
|
|
208 |
|
|
|
209 |
# Use 1 as a label for the target domain |
|
|
210 |
loss_discr_1 = self.losses['discr']( |
|
|
211 |
input=zs_pred_maknee, |
|
|
212 |
target=torch.ones_like(zs_pred_maknee, device=maybe_gpu)) |
|
|
213 |
loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR |
|
|
214 |
loss_discr_1.backward() |
|
|
215 |
metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item()) |
|
|
216 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
217 |
metrics_acc['batchw']['loss_discr_1'][-1] |
|
|
218 |
|
|
|
219 |
self.models['segm'].zero_grad() |
|
|
220 |
self.optimizers['discr'].step() |
|
|
221 |
self.models['discr'].zero_grad() |
|
|
222 |
|
|
|
223 |
# ---------------- Train segmentation network ------------ |
|
|
224 |
# With source |
|
|
225 |
if not self.config['with_mixup']: |
|
|
226 |
ys_pred_oai = self.models['segm'](xs_oai) |
|
|
227 |
loss_segm = self.losses['segm'](input_=ys_pred_oai, |
|
|
228 |
target=ys_true_arg_oai) |
|
|
229 |
else: |
|
|
230 |
xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data( |
|
|
231 |
x=xs_oai, y=ys_true_arg_oai, |
|
|
232 |
alpha=self.config['mixup_alpha'], device=maybe_gpu) |
|
|
233 |
ys_pred_oai = self.models['segm'](xs_mixup) |
|
|
234 |
loss_segm = mixup_criterion(criterion=self.losses['segm'], |
|
|
235 |
pred=ys_pred_oai, |
|
|
236 |
y_a=ys_mixup_a, |
|
|
237 |
y_b=ys_mixup_b, |
|
|
238 |
lam=lambda_mixup) |
|
|
239 |
|
|
|
240 |
loss_segm.backward(retain_graph=True) |
|
|
241 |
loss_segm = loss_segm * COEFF_SEGM |
|
|
242 |
metrics_acc['batchw']['loss_segm'].append(loss_segm.item()) |
|
|
243 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
244 |
metrics_acc['batchw']['loss_segm'][-1] |
|
|
245 |
|
|
|
246 |
# With target |
|
|
247 |
self.models['segm'] = self.models['segm'].eval() |
|
|
248 |
ys_pred_maknee = self.models['segm'](xs_maknee) |
|
|
249 |
self.models['segm'] = self.models['segm'].train() |
|
|
250 |
|
|
|
251 |
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1) |
|
|
252 |
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee) |
|
|
253 |
|
|
|
254 |
# Use 0 as a label for the source domain |
|
|
255 |
loss_advers = self.losses['advers']( |
|
|
256 |
input=zs_pred_maknee, |
|
|
257 |
target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu)) |
|
|
258 |
loss_advers = loss_advers * COEFF_ADVERS |
|
|
259 |
loss_advers.backward() |
|
|
260 |
metrics_acc['batchw']['loss_advers'].append(loss_advers.item()) |
|
|
261 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
262 |
metrics_acc['batchw']['loss_advers'][-1] |
|
|
263 |
|
|
|
264 |
self.models['discr'].zero_grad() |
|
|
265 |
self.optimizers['segm'].step() |
|
|
266 |
|
|
|
267 |
if step_idx % 10 == 0: |
|
|
268 |
self.tensorboard.add_scalars( |
|
|
269 |
f'fold_{self.fold_idx}/losses_train', |
|
|
270 |
{'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1], |
|
|
271 |
'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1], |
|
|
272 |
'discr_sum_batchw': |
|
|
273 |
(metrics_acc['batchw']['loss_discr_0'][-1] + |
|
|
274 |
metrics_acc['batchw']['loss_discr_1'][-1]), |
|
|
275 |
'segm_batchw': metrics_acc['batchw']['loss_segm'][-1], |
|
|
276 |
'advers_batchw': |
|
|
277 |
metrics_acc['batchw']['loss_advers'][-1], |
|
|
278 |
'total_batchw': metrics_acc['batchw']['loss_total'][-1], |
|
|
279 |
}, global_step=(epoch_idx * steps_total + step_idx)) |
|
|
280 |
|
|
|
281 |
prog_bar.update(1) |
|
|
282 |
|
|
|
283 |
del [loader_oai_iter_old, loader_maknee_iter_old] |
|
|
284 |
gc.collect() |
|
|
285 |
else: |
|
|
286 |
# ----------------------- Validation regime ----------------------- |
|
|
287 |
loader_oai = loaders['oai_imo']['val'] |
|
|
288 |
loader_okoa = loaders['okoa']['val'] |
|
|
289 |
loader_maknee = loaders['maknee']['val'] |
|
|
290 |
|
|
|
291 |
steps_oai, steps_okoa, steps_maknee = len(loader_oai), len(loader_okoa), len(loader_maknee) |
|
|
292 |
steps_total = steps_oai |
|
|
293 |
prog_bar_params.update({'total': steps_total, |
|
|
294 |
'desc': f'Validate, epoch {epoch_idx}'}) |
|
|
295 |
|
|
|
296 |
loader_oai_iter = iter(loader_oai) |
|
|
297 |
loader_okoa_iter = iter(loader_okoa) |
|
|
298 |
loader_maknee_iter = iter(loader_maknee) |
|
|
299 |
|
|
|
300 |
loader_oai_iter_old = None |
|
|
301 |
loader_okoa_iter_old = None |
|
|
302 |
loader_maknee_iter_old = None |
|
|
303 |
|
|
|
304 |
with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar: |
|
|
305 |
for step_idx in range(steps_total): |
|
|
306 |
metrics_acc['batchw']['loss_total'].append(0) |
|
|
307 |
|
|
|
308 |
try: |
|
|
309 |
data_batch_oai = next(loader_oai_iter) |
|
|
310 |
except StopIteration: |
|
|
311 |
loader_oai_iter_old = loader_oai_iter |
|
|
312 |
loader_oai_iter = iter(loader_oai) |
|
|
313 |
data_batch_oai = next(loader_oai_iter) |
|
|
314 |
|
|
|
315 |
try: |
|
|
316 |
data_batch_okoa = next(loader_okoa_iter) |
|
|
317 |
except StopIteration: |
|
|
318 |
loader_okoa_iter_old = loader_okoa_iter |
|
|
319 |
loader_okoa_iter = iter(loader_okoa) |
|
|
320 |
data_batch_okoa = next(loader_okoa_iter) |
|
|
321 |
|
|
|
322 |
try: |
|
|
323 |
data_batch_maknee = next(loader_maknee_iter) |
|
|
324 |
except StopIteration: |
|
|
325 |
loader_maknee_iter_old = loader_maknee_iter |
|
|
326 |
loader_maknee_iter = iter(loader_maknee) |
|
|
327 |
data_batch_maknee = next(loader_maknee_iter) |
|
|
328 |
|
|
|
329 |
xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys'] |
|
|
330 |
fnames_acc['oai'].extend(data_batch_oai['path_image']) |
|
|
331 |
xs_oai = xs_oai.to(maybe_gpu) |
|
|
332 |
ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1) |
|
|
333 |
|
|
|
334 |
xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys'] |
|
|
335 |
fnames_acc['maknee'].extend(data_batch_maknee['path_image']) |
|
|
336 |
xs_maknee = xs_maknee.to(maybe_gpu) |
|
|
337 |
|
|
|
338 |
# -------------- Validate discriminator network ------------- |
|
|
339 |
# With source |
|
|
340 |
ys_pred_oai = self.models['segm'](xs_oai) |
|
|
341 |
ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1) |
|
|
342 |
|
|
|
343 |
zs_pred_oai = self.models['discr'](ys_pred_softmax_oai) |
|
|
344 |
|
|
|
345 |
# Use 0 as a label for the source domain |
|
|
346 |
loss_discr_0 = self.losses['discr']( |
|
|
347 |
input=zs_pred_oai, |
|
|
348 |
target=torch.zeros_like(zs_pred_oai, device=maybe_gpu)) |
|
|
349 |
loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR |
|
|
350 |
metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item()) |
|
|
351 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
352 |
metrics_acc['batchw']['loss_discr_0'][-1] |
|
|
353 |
|
|
|
354 |
# With target |
|
|
355 |
ys_pred_maknee = self.models['segm'](xs_maknee) |
|
|
356 |
|
|
|
357 |
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1) |
|
|
358 |
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee) |
|
|
359 |
|
|
|
360 |
# Use 1 as a label for the target domain |
|
|
361 |
loss_discr_1 = self.losses['discr']( |
|
|
362 |
input=zs_pred_maknee, |
|
|
363 |
target=torch.ones_like(zs_pred_oai, device=maybe_gpu)) |
|
|
364 |
loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR |
|
|
365 |
metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item()) |
|
|
366 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
367 |
metrics_acc['batchw']['loss_discr_1'][-1] |
|
|
368 |
|
|
|
369 |
# ---------------- Validate segmentation network ------------ |
|
|
370 |
# With source |
|
|
371 |
if not self.config['with_mixup']: |
|
|
372 |
ys_pred_oai = self.models['segm'](xs_oai) |
|
|
373 |
loss_segm = self.losses['segm'](input_=ys_pred_oai, |
|
|
374 |
target=ys_true_arg_oai) |
|
|
375 |
else: |
|
|
376 |
xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data( |
|
|
377 |
x=xs_oai, y=ys_true_arg_oai, |
|
|
378 |
alpha=self.config['mixup_alpha'], device=maybe_gpu) |
|
|
379 |
ys_pred_oai = self.models['segm'](xs_mixup) |
|
|
380 |
loss_segm = mixup_criterion(criterion=self.losses['segm'], |
|
|
381 |
pred=ys_pred_oai, |
|
|
382 |
y_a=ys_mixup_a, |
|
|
383 |
y_b=ys_mixup_b, |
|
|
384 |
lam=lambda_mixup) |
|
|
385 |
|
|
|
386 |
loss_segm = loss_segm * COEFF_SEGM |
|
|
387 |
metrics_acc['batchw']['loss_segm'].append(loss_segm.item()) |
|
|
388 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
389 |
metrics_acc['batchw']['loss_segm'][-1] |
|
|
390 |
|
|
|
391 |
# With target |
|
|
392 |
ys_pred_maknee = self.models['segm'](xs_maknee) |
|
|
393 |
|
|
|
394 |
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1) |
|
|
395 |
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee) |
|
|
396 |
|
|
|
397 |
# Use 0 as a label for the source domain |
|
|
398 |
loss_advers = self.losses['advers']( |
|
|
399 |
input=zs_pred_maknee, |
|
|
400 |
target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu)) |
|
|
401 |
loss_advers = loss_advers * COEFF_ADVERS |
|
|
402 |
metrics_acc['batchw']['loss_advers'].append(loss_advers.item()) |
|
|
403 |
metrics_acc['batchw']['loss_total'][-1] += \ |
|
|
404 |
metrics_acc['batchw']['loss_advers'][-1] |
|
|
405 |
|
|
|
406 |
if step_idx % 10 == 0: |
|
|
407 |
self.tensorboard.add_scalars( |
|
|
408 |
f'fold_{self.fold_idx}/losses_val', |
|
|
409 |
{'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1], |
|
|
410 |
'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1], |
|
|
411 |
'discr_sum_batchw': |
|
|
412 |
(metrics_acc['batchw']['loss_discr_0'][-1] + |
|
|
413 |
metrics_acc['batchw']['loss_discr_1'][-1]), |
|
|
414 |
'segm_batchw': metrics_acc['batchw']['loss_segm'][-1], |
|
|
415 |
'advers_batchw': |
|
|
416 |
metrics_acc['batchw']['loss_advers'][-1], |
|
|
417 |
'total_batchw': metrics_acc['batchw']['loss_total'][-1], |
|
|
418 |
}, global_step=(epoch_idx * steps_total + step_idx)) |
|
|
419 |
|
|
|
420 |
# ------------------ Calculate metrics ------------------- |
|
|
421 |
|
|
|
422 |
ys_pred_arg_np_oai = torch.argmax(ys_pred_softmax_oai, 1).to('cpu').numpy() |
|
|
423 |
ys_true_arg_np_oai = ys_true_arg_oai.to('cpu').numpy() |
|
|
424 |
|
|
|
425 |
metrics_acc['datasetw']['cm_oai'] += confusion_matrix( |
|
|
426 |
ys_pred_arg_np_oai, ys_true_arg_np_oai, |
|
|
427 |
self.config['output_channels']) |
|
|
428 |
|
|
|
429 |
# Don't consider repeating entries for the metrics calculation |
|
|
430 |
if step_idx < steps_okoa: |
|
|
431 |
xs_okoa, ys_true_okoa = data_batch_okoa['xs'], data_batch_okoa['ys'] |
|
|
432 |
fnames_acc['okoa'].extend(data_batch_okoa['path_image']) |
|
|
433 |
xs_okoa = xs_okoa.to(maybe_gpu) |
|
|
434 |
|
|
|
435 |
ys_pred_okoa = self.models['segm'](xs_okoa) |
|
|
436 |
|
|
|
437 |
ys_true_arg_okoa = torch.argmax(ys_true_okoa.long().to(maybe_gpu), dim=1) |
|
|
438 |
ys_pred_softmax_okoa = torch_fn.softmax(ys_pred_okoa, dim=1) |
|
|
439 |
|
|
|
440 |
ys_pred_arg_np_okoa = torch.argmax(ys_pred_softmax_okoa, 1).to('cpu').numpy() |
|
|
441 |
ys_true_arg_np_okoa = ys_true_arg_okoa.to('cpu').numpy() |
|
|
442 |
|
|
|
443 |
metrics_acc['datasetw']['cm_okoa'] += confusion_matrix( |
|
|
444 |
ys_pred_arg_np_okoa, ys_true_arg_np_okoa, |
|
|
445 |
self.config['output_channels']) |
|
|
446 |
|
|
|
447 |
prog_bar.update(1) |
|
|
448 |
|
|
|
449 |
del [loader_oai_iter_old, loader_okoa_iter_old, loader_maknee_iter_old] |
|
|
450 |
gc.collect() |
|
|
451 |
|
|
|
452 |
for k, v in metrics_acc['samplew'].items(): |
|
|
453 |
metrics_acc['samplew'][k] = np.asarray(v) |
|
|
454 |
metrics_acc['datasetw']['dice_score_oai'] = np.asarray( |
|
|
455 |
dice_score_from_cm(metrics_acc['datasetw']['cm_oai'])) |
|
|
456 |
metrics_acc['datasetw']['dice_score_okoa'] = np.asarray( |
|
|
457 |
dice_score_from_cm(metrics_acc['datasetw']['cm_okoa'])) |
|
|
458 |
return metrics_acc, fnames_acc |
|
|
459 |
|
|
|
460 |
def fit(self, loaders): |
|
|
461 |
epoch_idx_best = -1 |
|
|
462 |
loss_best = float('inf') |
|
|
463 |
metrics_train_best = dict() |
|
|
464 |
fnames_train_best = [] |
|
|
465 |
metrics_val_best = dict() |
|
|
466 |
fnames_val_best = [] |
|
|
467 |
|
|
|
468 |
for epoch_idx in range(self.config['epoch_num']): |
|
|
469 |
self.models = {n: m.train() for n, m in self.models.items()} |
|
|
470 |
metrics_train, fnames_train = \ |
|
|
471 |
self.run_one_epoch(epoch_idx, loaders) |
|
|
472 |
|
|
|
473 |
# Process the accumulated metrics |
|
|
474 |
for k, v in metrics_train['batchw'].items(): |
|
|
475 |
if k.startswith('loss'): |
|
|
476 |
metrics_train['datasetw'][k] = np.mean(np.asarray(v)) |
|
|
477 |
else: |
|
|
478 |
logger.warning(f'Non-processed batch-wise entry: {k}') |
|
|
479 |
|
|
|
480 |
self.models = {n: m.eval() for n, m in self.models.items()} |
|
|
481 |
metrics_val, fnames_val = \ |
|
|
482 |
self.run_one_epoch(epoch_idx, loaders) |
|
|
483 |
|
|
|
484 |
# Process the accumulated metrics |
|
|
485 |
for k, v in metrics_val['batchw'].items(): |
|
|
486 |
if k.startswith('loss'): |
|
|
487 |
metrics_val['datasetw'][k] = np.mean(np.asarray(v)) |
|
|
488 |
else: |
|
|
489 |
logger.warning(f'Non-processed batch-wise entry: {k}') |
|
|
490 |
|
|
|
491 |
# Learning rate update |
|
|
492 |
for s, m in self.lr_update_rule.items(): |
|
|
493 |
if epoch_idx == s: |
|
|
494 |
for name, optim in self.optimizers.items(): |
|
|
495 |
for param_group in optim.param_groups: |
|
|
496 |
param_group['lr'] *= m |
|
|
497 |
|
|
|
498 |
# Add console logging |
|
|
499 |
logger.info(f'Epoch: {epoch_idx}') |
|
|
500 |
for subset, metrics in (('train', metrics_train), |
|
|
501 |
('val', metrics_val)): |
|
|
502 |
logger.info(f'{subset} metrics:') |
|
|
503 |
for k, v in metrics['datasetw'].items(): |
|
|
504 |
logger.info(f'{k}: \n{v}') |
|
|
505 |
|
|
|
506 |
# Add TensorBoard logging |
|
|
507 |
for subset, metrics in (('train', metrics_train), |
|
|
508 |
('val', metrics_val)): |
|
|
509 |
# Log only dataset-reduced metrics |
|
|
510 |
for k, v in metrics['datasetw'].items(): |
|
|
511 |
if isinstance(v, np.ndarray): |
|
|
512 |
self.tensorboard.add_scalars( |
|
|
513 |
f'fold_{self.fold_idx}/{k}_{subset}', |
|
|
514 |
{f'class{i}': e for i, e in enumerate(v.ravel().tolist())}, |
|
|
515 |
global_step=epoch_idx) |
|
|
516 |
elif isinstance(v, (str, int, float)): |
|
|
517 |
self.tensorboard.add_scalar( |
|
|
518 |
f'fold_{self.fold_idx}/{k}_{subset}', |
|
|
519 |
float(v), |
|
|
520 |
global_step=epoch_idx) |
|
|
521 |
else: |
|
|
522 |
logger.warning(f'{k} is of unsupported dtype {v}') |
|
|
523 |
for name, optim in self.optimizers.items(): |
|
|
524 |
for param_group in optim.param_groups: |
|
|
525 |
self.tensorboard.add_scalar( |
|
|
526 |
f'fold_{self.fold_idx}/learning_rate/{name}', |
|
|
527 |
param_group['lr'], |
|
|
528 |
global_step=epoch_idx) |
|
|
529 |
|
|
|
530 |
# Save the model |
|
|
531 |
loss_curr = metrics_val['datasetw']['loss_total'] |
|
|
532 |
if loss_curr < loss_best: |
|
|
533 |
loss_best = loss_curr |
|
|
534 |
epoch_idx_best = epoch_idx |
|
|
535 |
metrics_train_best = metrics_train |
|
|
536 |
metrics_val_best = metrics_val |
|
|
537 |
fnames_train_best = fnames_train |
|
|
538 |
fnames_val_best = fnames_val |
|
|
539 |
|
|
|
540 |
self.handlers_ckpt['segm'].save_new_ckpt( |
|
|
541 |
model=self.models['segm'], |
|
|
542 |
model_name=self.config['model_segm'], |
|
|
543 |
fold_idx=self.fold_idx, |
|
|
544 |
epoch_idx=epoch_idx) |
|
|
545 |
self.handlers_ckpt['discr'].save_new_ckpt( |
|
|
546 |
model=self.models['discr'], |
|
|
547 |
model_name=self.config['model_discr'], |
|
|
548 |
fold_idx=self.fold_idx, |
|
|
549 |
epoch_idx=epoch_idx) |
|
|
550 |
|
|
|
551 |
msg = (f'Finished fold {self.fold_idx} ' |
|
|
552 |
f'with the best loss {loss_best:.5f} ' |
|
|
553 |
f'on epoch {epoch_idx_best}, ' |
|
|
554 |
f'weights: ({self.paths_weights_fold})') |
|
|
555 |
logger.info(msg) |
|
|
556 |
return (metrics_train_best, fnames_train_best, |
|
|
557 |
metrics_val_best, fnames_val_best) |
|
|
558 |
|
|
|
559 |
|
|
|
560 |
@click.command() |
|
|
561 |
@click.option('--path_data_root', default='../../data') |
|
|
562 |
@click.option('--path_experiment_root', default='../../results/temporary') |
|
|
563 |
@click.option('--model_segm', default='unet_lext') |
|
|
564 |
@click.option('--center_depth', default=1, type=int) |
|
|
565 |
@click.option('--model_discr', default='discriminator_a') |
|
|
566 |
@click.option('--pretrained', is_flag=True) |
|
|
567 |
@click.option('--path_pretrained_segm', type=str, help='Path to .pth file') |
|
|
568 |
@click.option('--restore_weights', is_flag=True) |
|
|
569 |
@click.option('--input_channels', default=1, type=int) |
|
|
570 |
@click.option('--output_channels', default=1, type=int) |
|
|
571 |
@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str) |
|
|
572 |
@click.option('--sample_mode', default='x_y', type=str) |
|
|
573 |
@click.option('--loss_segm', default='multi_ce_loss') |
|
|
574 |
@click.option('--lr_segm', default=0.0001, type=float) |
|
|
575 |
@click.option('--lr_discr', default=0.0001, type=float) |
|
|
576 |
@click.option('--wd_segm', default=5e-5, type=float) |
|
|
577 |
@click.option('--wd_discr', default=5e-5, type=float) |
|
|
578 |
@click.option('--optimizer_segm', default='adam') |
|
|
579 |
@click.option('--optimizer_discr', default='adam') |
|
|
580 |
@click.option('--batch_size', default=64, type=int) |
|
|
581 |
@click.option('--epoch_size', default=1.0, type=float) |
|
|
582 |
@click.option('--epoch_num', default=2, type=int) |
|
|
583 |
@click.option('--fold_num', default=5, type=int) |
|
|
584 |
@click.option('--fold_idx', default=-1, type=int) |
|
|
585 |
@click.option('--fold_idx_ignore', multiple=True, type=int) |
|
|
586 |
@click.option('--num_workers', default=1, type=int) |
|
|
587 |
@click.option('--seed_trainval_test', default=0, type=int) |
|
|
588 |
@click.option('--with_mixup', is_flag=True) |
|
|
589 |
@click.option('--mixup_alpha', default=1, type=float) |
|
|
590 |
def main(**config): |
|
|
591 |
config['path_data_root'] = os.path.abspath(config['path_data_root']) |
|
|
592 |
config['path_experiment_root'] = os.path.abspath(config['path_experiment_root']) |
|
|
593 |
|
|
|
594 |
config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights') |
|
|
595 |
config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train') |
|
|
596 |
os.makedirs(config['path_weights'], exist_ok=True) |
|
|
597 |
os.makedirs(config['path_logs'], exist_ok=True) |
|
|
598 |
|
|
|
599 |
logging_fh = logging.FileHandler( |
|
|
600 |
os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx']))) |
|
|
601 |
logging_fh.setLevel(logging.DEBUG) |
|
|
602 |
logger.addHandler(logging_fh) |
|
|
603 |
|
|
|
604 |
# Collect the available and specified sources |
|
|
605 |
sources = sources_from_path(path_data_root=config['path_data_root'], |
|
|
606 |
selection=('oai_imo', 'okoa', 'maknee'), |
|
|
607 |
with_folds=True, |
|
|
608 |
fold_num=config['fold_num'], |
|
|
609 |
seed_trainval_test=config['seed_trainval_test']) |
|
|
610 |
|
|
|
611 |
# Build a list of folds to run on |
|
|
612 |
if config['fold_idx'] == -1: |
|
|
613 |
fold_idcs = list(range(config['fold_num'])) |
|
|
614 |
else: |
|
|
615 |
fold_idcs = [config['fold_idx'], ] |
|
|
616 |
for g in config['fold_idx_ignore']: |
|
|
617 |
fold_idcs = [i for i in fold_idcs if i != g] |
|
|
618 |
|
|
|
619 |
# Train each fold separately |
|
|
620 |
fold_scores = dict() |
|
|
621 |
|
|
|
622 |
# Use straightforward fold allocation strategy |
|
|
623 |
folds = list(zip(sources['oai_imo']['trainval_folds'], |
|
|
624 |
sources['okoa']['trainval_folds'], |
|
|
625 |
sources['maknee']['trainval_folds'])) |
|
|
626 |
|
|
|
627 |
for fold_idx, idcs_subsets in enumerate(folds): |
|
|
628 |
if fold_idx not in fold_idcs: |
|
|
629 |
continue |
|
|
630 |
logger.info(f'Training fold {fold_idx}') |
|
|
631 |
|
|
|
632 |
(sources['oai_imo']['train_idcs'], sources['oai_imo']['val_idcs']) = idcs_subsets[0] |
|
|
633 |
(sources['okoa']['train_idcs'], sources['okoa']['val_idcs']) = idcs_subsets[1] |
|
|
634 |
(sources['maknee']['train_idcs'], sources['maknee']['val_idcs']) = idcs_subsets[2] |
|
|
635 |
|
|
|
636 |
sources['oai_imo']['train_df'] = sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['train_idcs']] |
|
|
637 |
sources['oai_imo']['val_df'] = sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['val_idcs']] |
|
|
638 |
sources['okoa']['train_df'] = sources['okoa']['trainval_df'].iloc[sources['okoa']['train_idcs']] |
|
|
639 |
sources['okoa']['val_df'] = sources['okoa']['trainval_df'].iloc[sources['okoa']['val_idcs']] |
|
|
640 |
sources['maknee']['train_df'] = sources['maknee']['trainval_df'].iloc[sources['maknee']['train_idcs']] |
|
|
641 |
sources['maknee']['val_df'] = sources['maknee']['trainval_df'].iloc[sources['maknee']['val_idcs']] |
|
|
642 |
|
|
|
643 |
for n, s in sources.items(): |
|
|
644 |
logger.info('Made {} train-val split, number of samples: {}, {}' |
|
|
645 |
.format(n, len(s['train_df']), len(s['val_df']))) |
|
|
646 |
|
|
|
647 |
datasets = defaultdict(dict) |
|
|
648 |
|
|
|
649 |
datasets['oai_imo']['train'] = DatasetOAIiMoSagittal2d( |
|
|
650 |
df_meta=sources['oai_imo']['train_df'], |
|
|
651 |
mask_mode=config['mask_mode'], |
|
|
652 |
sample_mode=config['sample_mode'], |
|
|
653 |
transforms=[ |
|
|
654 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
655 |
CenterCrop(height=300, width=300), |
|
|
656 |
HorizontalFlip(prob=.5), |
|
|
657 |
GammaCorrection(gamma_range=(0.5, 1.5), prob=.5), |
|
|
658 |
OneOf([ |
|
|
659 |
DualCompose([ |
|
|
660 |
Scale(ratio_range=(0.7, 0.8), prob=1.), |
|
|
661 |
Scale(ratio_range=(1.5, 1.6), prob=1.), |
|
|
662 |
]), |
|
|
663 |
NoTransform() |
|
|
664 |
]), |
|
|
665 |
Crop(output_size=(300, 300)), |
|
|
666 |
BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3), |
|
|
667 |
Normalize(mean=0.252699, std=0.251142), |
|
|
668 |
ToTensor(), |
|
|
669 |
]) |
|
|
670 |
datasets['okoa']['train'] = DatasetOKOASagittal2d( |
|
|
671 |
df_meta=sources['okoa']['train_df'], |
|
|
672 |
mask_mode='background_femoral_unitibial', |
|
|
673 |
sample_mode=config['sample_mode'], |
|
|
674 |
transforms=[ |
|
|
675 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
676 |
CenterCrop(height=300, width=300), |
|
|
677 |
HorizontalFlip(prob=.5), |
|
|
678 |
GammaCorrection(gamma_range=(0.5, 1.5), prob=.5), |
|
|
679 |
OneOf([ |
|
|
680 |
DualCompose([ |
|
|
681 |
Scale(ratio_range=(0.7, 0.8), prob=1.), |
|
|
682 |
Scale(ratio_range=(1.5, 1.6), prob=1.), |
|
|
683 |
]), |
|
|
684 |
NoTransform() |
|
|
685 |
]), |
|
|
686 |
Crop(output_size=(300, 300)), |
|
|
687 |
BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3), |
|
|
688 |
|
|
|
689 |
Normalize(mean=0.252699, std=0.251142), |
|
|
690 |
ToTensor(), |
|
|
691 |
]) |
|
|
692 |
datasets['maknee']['train'] = DatasetMAKNEESagittal2d( |
|
|
693 |
df_meta=sources['maknee']['train_df'], |
|
|
694 |
mask_mode='', |
|
|
695 |
sample_mode=config['sample_mode'], |
|
|
696 |
transforms=[ |
|
|
697 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
698 |
CenterCrop(height=300, width=300), |
|
|
699 |
HorizontalFlip(prob=.5), |
|
|
700 |
GammaCorrection(gamma_range=(0.5, 1.5), prob=.5), |
|
|
701 |
OneOf([ |
|
|
702 |
DualCompose([ |
|
|
703 |
Scale(ratio_range=(0.7, 0.8), prob=1.), |
|
|
704 |
Scale(ratio_range=(1.5, 1.6), prob=1.), |
|
|
705 |
]), |
|
|
706 |
NoTransform() |
|
|
707 |
]), |
|
|
708 |
Crop(output_size=(300, 300)), |
|
|
709 |
BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3), |
|
|
710 |
Normalize(mean=0.252699, std=0.251142), |
|
|
711 |
ToTensor(), |
|
|
712 |
]) |
|
|
713 |
datasets['oai_imo']['val'] = DatasetOAIiMoSagittal2d( |
|
|
714 |
df_meta=sources['oai_imo']['val_df'], |
|
|
715 |
mask_mode=config['mask_mode'], |
|
|
716 |
sample_mode=config['sample_mode'], |
|
|
717 |
transforms=[ |
|
|
718 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
719 |
CenterCrop(height=300, width=300), |
|
|
720 |
Normalize(mean=0.252699, std=0.251142), |
|
|
721 |
ToTensor() |
|
|
722 |
]) |
|
|
723 |
datasets['okoa']['val'] = DatasetOKOASagittal2d( |
|
|
724 |
df_meta=sources['okoa']['val_df'], |
|
|
725 |
mask_mode='background_femoral_unitibial', |
|
|
726 |
sample_mode=config['sample_mode'], |
|
|
727 |
transforms=[ |
|
|
728 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
729 |
CenterCrop(height=300, width=300), |
|
|
730 |
Normalize(mean=0.252699, std=0.251142), |
|
|
731 |
ToTensor() |
|
|
732 |
]) |
|
|
733 |
datasets['maknee']['val'] = DatasetMAKNEESagittal2d( |
|
|
734 |
df_meta=sources['maknee']['val_df'], |
|
|
735 |
mask_mode='', |
|
|
736 |
sample_mode=config['sample_mode'], |
|
|
737 |
transforms=[ |
|
|
738 |
PercentileClippingAndToFloat(cut_min=10, cut_max=99), |
|
|
739 |
CenterCrop(height=300, width=300), |
|
|
740 |
Normalize(mean=0.252699, std=0.251142), |
|
|
741 |
ToTensor() |
|
|
742 |
]) |
|
|
743 |
|
|
|
744 |
loaders = defaultdict(dict) |
|
|
745 |
|
|
|
746 |
loaders['oai_imo']['train'] = DataLoader( |
|
|
747 |
datasets['oai_imo']['train'], |
|
|
748 |
batch_size=int(config['batch_size'] / 2), |
|
|
749 |
shuffle=True, |
|
|
750 |
num_workers=config['num_workers'], |
|
|
751 |
drop_last=True) |
|
|
752 |
loaders['oai_imo']['val'] = DataLoader( |
|
|
753 |
datasets['oai_imo']['val'], |
|
|
754 |
batch_size=int(config['batch_size'] / 2), |
|
|
755 |
shuffle=False, |
|
|
756 |
num_workers=config['num_workers'], |
|
|
757 |
drop_last=True) |
|
|
758 |
loaders['okoa']['train'] = DataLoader( |
|
|
759 |
datasets['okoa']['train'], |
|
|
760 |
batch_size=int(config['batch_size'] / 2), |
|
|
761 |
shuffle=True, |
|
|
762 |
num_workers=config['num_workers'], |
|
|
763 |
drop_last=True) |
|
|
764 |
loaders['okoa']['val'] = DataLoader( |
|
|
765 |
datasets['okoa']['val'], |
|
|
766 |
batch_size=int(config['batch_size'] / 2), |
|
|
767 |
shuffle=False, |
|
|
768 |
num_workers=config['num_workers'], |
|
|
769 |
drop_last=True) |
|
|
770 |
loaders['maknee']['train'] = DataLoader( |
|
|
771 |
datasets['maknee']['train'], |
|
|
772 |
batch_size=int(config['batch_size'] / 2), |
|
|
773 |
shuffle=True, |
|
|
774 |
num_workers=config['num_workers'], |
|
|
775 |
drop_last=True) |
|
|
776 |
loaders['maknee']['val'] = DataLoader( |
|
|
777 |
datasets['maknee']['val'], |
|
|
778 |
batch_size=int(config['batch_size'] / 2), |
|
|
779 |
shuffle=False, |
|
|
780 |
num_workers=config['num_workers'], |
|
|
781 |
drop_last=True) |
|
|
782 |
|
|
|
783 |
trainer = ModelTrainer(config=config, fold_idx=fold_idx) |
|
|
784 |
|
|
|
785 |
tmp = trainer.fit(loaders=loaders) |
|
|
786 |
metrics_train, fnames_train, metrics_val, fnames_val = tmp |
|
|
787 |
|
|
|
788 |
fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'], |
|
|
789 |
metrics_val['datasetw']['dice_score_okoa']) |
|
|
790 |
trainer.tensorboard.close() |
|
|
791 |
logger.info(f'Fold scores:\n{repr(fold_scores)}') |
|
|
792 |
|
|
|
793 |
|
|
|
794 |
if __name__ == '__main__': |
|
|
795 |
main() |