Diff of /rocaseg/train_uda1.py [000000] .. [6969be]

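"""Train a knee-MRI segmentation model with adversarial unsupervised domain
adaptation (UDA).

The segmentation network is supervised on the labeled source domain (OAI
iMorphics), while a discriminator operating on the softmax output maps pushes
the predictions on the unlabeled target domain (MAKNEE) towards the source
distribution. In this script, OKOA data are used only for cross-domain
validation metrics.
"""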
import os
import logging
from collections import defaultdict
import gc
import click
import resource

import numpy as np
import cv2

import torch
import torch.nn.functional as torch_fn
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from rocaseg.datasets import (DatasetOAIiMoSagittal2d,
                              DatasetOKOASagittal2d,
                              DatasetMAKNEESagittal2d,
                              sources_from_path)
from rocaseg.models import dict_models
from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm,
                                dict_optimizers, CheckpointHandler)
from rocaseg.preproc import *
from rocaseg.repro import set_ultimate_seed
from rocaseg.components.mixup import mixup_criterion, mixup_data


rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

logging.basicConfig()
logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)

set_ultimate_seed()

if torch.cuda.is_available():
    maybe_gpu = 'cuda'
else:
    maybe_gpu = 'cpu'


class ModelTrainer:
    def __init__(self, config, fold_idx=None):
        self.config = config
        self.fold_idx = fold_idx

        self.paths_weights_fold = dict()
        self.paths_weights_fold['segm'] = \
            os.path.join(config['path_weights'], 'segm', f'fold_{self.fold_idx}')
        os.makedirs(self.paths_weights_fold['segm'], exist_ok=True)
        self.paths_weights_fold['discr'] = \
            os.path.join(config['path_weights'], 'discr', f'fold_{self.fold_idx}')
        os.makedirs(self.paths_weights_fold['discr'], exist_ok=True)

        self.path_logs_fold = \
            os.path.join(config['path_logs'], f'fold_{self.fold_idx}')
        os.makedirs(self.path_logs_fold, exist_ok=True)

        self.handlers_ckpt = dict()
        self.handlers_ckpt['segm'] = CheckpointHandler(self.paths_weights_fold['segm'])
        self.handlers_ckpt['discr'] = CheckpointHandler(self.paths_weights_fold['discr'])

        paths_ckpt_sel = dict()
        paths_ckpt_sel['segm'] = self.handlers_ckpt['segm'].get_last_ckpt()
        paths_ckpt_sel['discr'] = self.handlers_ckpt['discr'].get_last_ckpt()

        # Initialize and configure the models
        self.models = dict()
        self.models['segm'] = (dict_models[config['model_segm']]
                               (input_channels=self.config['input_channels'],
                                output_channels=self.config['output_channels'],
                                center_depth=self.config['center_depth'],
                                pretrained=self.config['pretrained'],
                                path_pretrained=self.config['path_pretrained_segm'],
                                restore_weights=self.config['restore_weights'],
                                path_weights=paths_ckpt_sel['segm']))
        self.models['segm'] = nn.DataParallel(self.models['segm'])
        self.models['segm'] = self.models['segm'].to(maybe_gpu)

        self.models['discr'] = (dict_models[config['model_discr']]
                                (input_channels=self.config['output_channels'],
                                 output_channels=1,
                                 pretrained=self.config['pretrained'],
                                 restore_weights=self.config['restore_weights'],
                                 path_weights=paths_ckpt_sel['discr']))
        self.models['discr'] = nn.DataParallel(self.models['discr'])
        self.models['discr'] = self.models['discr'].to(maybe_gpu)

        # Configure the training
        self.optimizers = dict()
        self.optimizers['segm'] = (dict_optimizers['adam'](
            self.models['segm'].parameters(),
            lr=self.config['lr_segm'],
            weight_decay=self.config['wd_segm']))
        self.optimizers['discr'] = (dict_optimizers['adam'](
            self.models['discr'].parameters(),
            lr=self.config['lr_discr'],
            weight_decay=self.config['wd_discr']))

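        # Learning-rate schedule: multiply the learning rate of both optimizers
        # by 0.1 once epoch 25 is reached (applied in `fit`)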
        self.lr_update_rule = {25: 0.1}

        self.losses = dict()
        self.losses['segm'] = dict_losses[self.config['loss_segm']](
            num_classes=self.config['output_channels'],
        )
        self.losses['advers'] = dict_losses['bce_loss']()
        self.losses['discr'] = dict_losses['bce_loss']()

        self.losses['segm'] = self.losses['segm'].to(maybe_gpu)
        self.losses['advers'] = self.losses['advers'].to(maybe_gpu)
        self.losses['discr'] = self.losses['discr'].to(maybe_gpu)

        self.tensorboard = SummaryWriter(self.path_logs_fold)

    def run_one_epoch(self, epoch_idx, loaders):
        COEFF_DISCR = 1
        COEFF_SEGM = 1
        COEFF_ADVERS = 0.001
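        # Relative loss weighting: the supervised segmentation term and the
        # discriminator terms enter at full weight, whereas the adversarial
        # term is scaled down by 1000x so that it nudges the segmentation
        # network rather than dominating the supervised objective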

        fnames_acc = defaultdict(list)
        metrics_acc = dict()
        metrics_acc['samplew'] = defaultdict(list)
        metrics_acc['batchw'] = defaultdict(list)
        metrics_acc['datasetw'] = defaultdict(list)
        metrics_acc['datasetw']['cm_oai'] = \
            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
        metrics_acc['datasetw']['cm_okoa'] = \
            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)

        prog_bar_params = {'postfix': {'epoch': epoch_idx}}

        if self.models['segm'].training and self.models['discr'].training:
            # ------------------------ Training regime ------------------------
            loader_oai = loaders['oai_imo']['train']
            loader_maknee = loaders['maknee']['train']

            steps_oai, steps_maknee = len(loader_oai), len(loader_maknee)
            steps_total = steps_oai
            prog_bar_params.update({'total': steps_total,
                                    'desc': f'Train, epoch {epoch_idx}'})

            loader_oai_iter = iter(loader_oai)
            loader_maknee_iter = iter(loader_maknee)

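            # One epoch spans the source (OAI) loader once; the target (MAKNEE)
            # iterator is restarted whenever it runs out. Exhausted iterators
            # are kept alive until the end of the epoch (see the `del` below),
            # presumably to avoid tearing down DataLoader workers mid-epoch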
            loader_oai_iter_old = None
            loader_maknee_iter_old = None

            with tqdm(**prog_bar_params) as prog_bar:
                for step_idx in range(steps_total):
                    self.optimizers['segm'].zero_grad()
                    self.optimizers['discr'].zero_grad()

                    metrics_acc['batchw']['loss_total'].append(0)

                    try:
                        data_batch_oai = next(loader_oai_iter)
                    except StopIteration:
                        loader_oai_iter_old = loader_oai_iter
                        loader_oai_iter = iter(loader_oai)
                        data_batch_oai = next(loader_oai_iter)

                    try:
                        data_batch_maknee = next(loader_maknee_iter)
                    except StopIteration:
                        loader_maknee_iter_old = loader_maknee_iter
                        loader_maknee_iter = iter(loader_maknee)
                        data_batch_maknee = next(loader_maknee_iter)

                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
                    xs_oai = xs_oai.to(maybe_gpu)
                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)

                    xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
                    xs_maknee = xs_maknee.to(maybe_gpu)

                    # -------------- Train discriminator network -------------
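                    # Output-space adversarial adaptation: the discriminator
                    # learns to tell source softmax maps (label 0) from target
                    # ones (label 1); the segmentation network is then updated
                    # so that its target-domain outputs pass as source ones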
                    # With source
                    ys_pred_oai = self.models['segm'](xs_oai)
                    ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)

                    zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)

                    # Use 0 as a label for the source domain
                    loss_discr_0 = self.losses['discr'](
                        input=zs_pred_oai,
                        target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
                    loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
                    loss_discr_0.backward(retain_graph=True)
                    metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_discr_0'][-1]

                    # With target
                    self.models['segm'] = self.models['segm'].eval()
203
                    ys_pred_maknee = self.models['segm'](xs_maknee)
204
                    self.models['segm'] = self.models['segm'].train()
205
206
                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
207
                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
208
209
                    # Use 1 as a label for the target domain
210
                    loss_discr_1 = self.losses['discr'](
211
                        input=zs_pred_maknee,
212
                        target=torch.ones_like(zs_pred_maknee, device=maybe_gpu))
213
                    loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
214
                    loss_discr_1.backward()
215
                    metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
216
                    metrics_acc['batchw']['loss_total'][-1] += \
217
                        metrics_acc['batchw']['loss_discr_1'][-1]
218
219
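                    # Only the discriminator is updated here: the gradients
                    # that the backward passes above accumulated in the
                    # segmentation network are discarded, and the
                    # discriminator's own gradients are cleared right after
                    # the step so they do not carry into the phase below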
                    self.models['segm'].zero_grad()
                    self.optimizers['discr'].step()
                    self.models['discr'].zero_grad()

                    # ---------------- Train segmentation network ------------
                    # With source
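                    # With mixup enabled, the source images and labels are
                    # convexly combined, and the loss is the matching convex
                    # combination of the losses against both label sets
                    # (see rocaseg.components.mixup)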
                    if not self.config['with_mixup']:
                        ys_pred_oai = self.models['segm'](xs_oai)
                        loss_segm = self.losses['segm'](input_=ys_pred_oai,
                                                        target=ys_true_arg_oai)
                    else:
                        xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
                            x=xs_oai, y=ys_true_arg_oai,
                            alpha=self.config['mixup_alpha'], device=maybe_gpu)
                        ys_pred_oai = self.models['segm'](xs_mixup)
                        loss_segm = mixup_criterion(criterion=self.losses['segm'],
                                                    pred=ys_pred_oai,
                                                    y_a=ys_mixup_a,
                                                    y_b=ys_mixup_b,
                                                    lam=lambda_mixup)

                    loss_segm = loss_segm * COEFF_SEGM
                    loss_segm.backward(retain_graph=True)
                    metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_segm'][-1]

                    # With target
                    self.models['segm'] = self.models['segm'].eval()
                    ys_pred_maknee = self.models['segm'](xs_maknee)
                    self.models['segm'] = self.models['segm'].train()

                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)

                    # Label the target predictions with 0 (source) to train
                    # the segmentation network to fool the discriminator
                    loss_advers = self.losses['advers'](
                        input=zs_pred_maknee,
                        target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
                    loss_advers = loss_advers * COEFF_ADVERS
                    loss_advers.backward()
                    metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_advers'][-1]

                    self.models['discr'].zero_grad()
                    self.optimizers['segm'].step()

                    if step_idx % 10 == 0:
                        self.tensorboard.add_scalars(
                            f'fold_{self.fold_idx}/losses_train',
                            {'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
                             'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
                             'discr_sum_batchw':
                                 (metrics_acc['batchw']['loss_discr_0'][-1] +
                                  metrics_acc['batchw']['loss_discr_1'][-1]),
                             'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
                             'advers_batchw':
                                 metrics_acc['batchw']['loss_advers'][-1],
                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
                             }, global_step=(epoch_idx * steps_total + step_idx))

                    prog_bar.update(1)

            del [loader_oai_iter_old, loader_maknee_iter_old]
            gc.collect()
        else:
            # ----------------------- Validation regime -----------------------
            loader_oai = loaders['oai_imo']['val']
            loader_okoa = loaders['okoa']['val']
            loader_maknee = loaders['maknee']['val']

            steps_oai, steps_okoa, steps_maknee = len(loader_oai), len(loader_okoa), len(loader_maknee)
            steps_total = steps_oai
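            # Validation also runs for `len(loader_oai)` steps; the OKOA and
            # MAKNEE iterators are recycled when exhausted, and the recycled
            # OKOA batches are excluded from the metrics via the
            # `step_idx < steps_okoa` check below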
            prog_bar_params.update({'total': steps_total,
                                    'desc': f'Validate, epoch {epoch_idx}'})

            loader_oai_iter = iter(loader_oai)
            loader_okoa_iter = iter(loader_okoa)
            loader_maknee_iter = iter(loader_maknee)

            loader_oai_iter_old = None
            loader_okoa_iter_old = None
            loader_maknee_iter_old = None

            with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar:
                for step_idx in range(steps_total):
                    metrics_acc['batchw']['loss_total'].append(0)

                    try:
                        data_batch_oai = next(loader_oai_iter)
                    except StopIteration:
                        loader_oai_iter_old = loader_oai_iter
                        loader_oai_iter = iter(loader_oai)
                        data_batch_oai = next(loader_oai_iter)

                    try:
                        data_batch_okoa = next(loader_okoa_iter)
                    except StopIteration:
                        loader_okoa_iter_old = loader_okoa_iter
                        loader_okoa_iter = iter(loader_okoa)
                        data_batch_okoa = next(loader_okoa_iter)

                    try:
                        data_batch_maknee = next(loader_maknee_iter)
                    except StopIteration:
                        loader_maknee_iter_old = loader_maknee_iter
                        loader_maknee_iter = iter(loader_maknee)
                        data_batch_maknee = next(loader_maknee_iter)

                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
                    xs_oai = xs_oai.to(maybe_gpu)
                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)

                    xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
                    xs_maknee = xs_maknee.to(maybe_gpu)

                    # -------------- Validate discriminator network -------------
                    # With source
                    ys_pred_oai = self.models['segm'](xs_oai)
                    ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)

                    zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)

                    # Use 0 as a label for the source domain
                    loss_discr_0 = self.losses['discr'](
                        input=zs_pred_oai,
                        target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
                    loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
                    metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_discr_0'][-1]

                    # With target
                    ys_pred_maknee = self.models['segm'](xs_maknee)

                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)

                    # Use 1 as a label for the target domain
                    loss_discr_1 = self.losses['discr'](
                        input=zs_pred_maknee,
                        target=torch.ones_like(zs_pred_maknee, device=maybe_gpu))
                    loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
                    metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_discr_1'][-1]

                    # ---------------- Validate segmentation network ------------
                    # With source
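                    # Note that mixup, when enabled, is applied at validation
                    # time as well, so the validation segmentation loss is
                    # computed under the same criterion as the training one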
                    if not self.config['with_mixup']:
                        ys_pred_oai = self.models['segm'](xs_oai)
                        loss_segm = self.losses['segm'](input_=ys_pred_oai,
                                                        target=ys_true_arg_oai)
                    else:
                        xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
                            x=xs_oai, y=ys_true_arg_oai,
                            alpha=self.config['mixup_alpha'], device=maybe_gpu)
                        ys_pred_oai = self.models['segm'](xs_mixup)
                        loss_segm = mixup_criterion(criterion=self.losses['segm'],
                                                    pred=ys_pred_oai,
                                                    y_a=ys_mixup_a,
                                                    y_b=ys_mixup_b,
                                                    lam=lambda_mixup)

                    loss_segm = loss_segm * COEFF_SEGM
                    metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_segm'][-1]

                    # With target
                    ys_pred_maknee = self.models['segm'](xs_maknee)

                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)

                    # Label the target predictions with 0 (source), mirroring
                    # the adversarial loss used during training
                    loss_advers = self.losses['advers'](
                        input=zs_pred_maknee,
                        target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
                    loss_advers = loss_advers * COEFF_ADVERS
                    metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
                    metrics_acc['batchw']['loss_total'][-1] += \
                        metrics_acc['batchw']['loss_advers'][-1]

                    if step_idx % 10 == 0:
                        self.tensorboard.add_scalars(
                            f'fold_{self.fold_idx}/losses_val',
                            {'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
                             'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
                             'discr_sum_batchw':
                                 (metrics_acc['batchw']['loss_discr_0'][-1] +
                                  metrics_acc['batchw']['loss_discr_1'][-1]),
                             'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
                             'advers_batchw':
                                 metrics_acc['batchw']['loss_advers'][-1],
                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
                             }, global_step=(epoch_idx * steps_total + step_idx))

                    # ------------------ Calculate metrics -------------------

                    ys_pred_arg_np_oai = torch.argmax(ys_pred_softmax_oai, 1).to('cpu').numpy()
                    ys_true_arg_np_oai = ys_true_arg_oai.to('cpu').numpy()

                    metrics_acc['datasetw']['cm_oai'] += confusion_matrix(
                        ys_pred_arg_np_oai, ys_true_arg_np_oai,
                        self.config['output_channels'])

                    # Exclude the recycled OKOA batches so that no sample is
                    # counted twice in the metrics
                    if step_idx < steps_okoa:
                        xs_okoa, ys_true_okoa = data_batch_okoa['xs'], data_batch_okoa['ys']
                        fnames_acc['okoa'].extend(data_batch_okoa['path_image'])
                        xs_okoa = xs_okoa.to(maybe_gpu)

                        ys_pred_okoa = self.models['segm'](xs_okoa)

                        ys_true_arg_okoa = torch.argmax(ys_true_okoa.long().to(maybe_gpu), dim=1)
                        ys_pred_softmax_okoa = torch_fn.softmax(ys_pred_okoa, dim=1)

                        ys_pred_arg_np_okoa = torch.argmax(ys_pred_softmax_okoa, 1).to('cpu').numpy()
                        ys_true_arg_np_okoa = ys_true_arg_okoa.to('cpu').numpy()

                        metrics_acc['datasetw']['cm_okoa'] += confusion_matrix(
                            ys_pred_arg_np_okoa, ys_true_arg_np_okoa,
                            self.config['output_channels'])

                    prog_bar.update(1)

            del [loader_oai_iter_old, loader_okoa_iter_old, loader_maknee_iter_old]
            gc.collect()

        for k, v in metrics_acc['samplew'].items():
            metrics_acc['samplew'][k] = np.asarray(v)
        metrics_acc['datasetw']['dice_score_oai'] = np.asarray(
            dice_score_from_cm(metrics_acc['datasetw']['cm_oai']))
        metrics_acc['datasetw']['dice_score_okoa'] = np.asarray(
            dice_score_from_cm(metrics_acc['datasetw']['cm_okoa']))
        return metrics_acc, fnames_acc

    def fit(self, loaders):
        epoch_idx_best = -1
        loss_best = float('inf')
        metrics_train_best = dict()
        fnames_train_best = []
        metrics_val_best = dict()
        fnames_val_best = []

        for epoch_idx in range(self.config['epoch_num']):
            self.models = {n: m.train() for n, m in self.models.items()}
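            # `run_one_epoch` dispatches on the models' `training` flag: with
            # both networks in train mode it runs the optimization loop,
            # otherwise the validation loop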
            metrics_train, fnames_train = \
                self.run_one_epoch(epoch_idx, loaders)

            # Process the accumulated metrics
            for k, v in metrics_train['batchw'].items():
                if k.startswith('loss'):
                    metrics_train['datasetw'][k] = np.mean(np.asarray(v))
                else:
                    logger.warning(f'Non-processed batch-wise entry: {k}')

            self.models = {n: m.eval() for n, m in self.models.items()}
            metrics_val, fnames_val = \
                self.run_one_epoch(epoch_idx, loaders)

            # Process the accumulated metrics
            for k, v in metrics_val['batchw'].items():
                if k.startswith('loss'):
                    metrics_val['datasetw'][k] = np.mean(np.asarray(v))
                else:
                    logger.warning(f'Non-processed batch-wise entry: {k}')

            # Learning rate update
            for s, m in self.lr_update_rule.items():
                if epoch_idx == s:
                    for name, optim in self.optimizers.items():
                        for param_group in optim.param_groups:
                            param_group['lr'] *= m

            # Add console logging
            logger.info(f'Epoch: {epoch_idx}')
            for subset, metrics in (('train', metrics_train),
                                    ('val', metrics_val)):
                logger.info(f'{subset} metrics:')
                for k, v in metrics['datasetw'].items():
                    logger.info(f'{k}: \n{v}')

            # Add TensorBoard logging
            for subset, metrics in (('train', metrics_train),
                                    ('val', metrics_val)):
                # Log only dataset-reduced metrics
                for k, v in metrics['datasetw'].items():
                    if isinstance(v, np.ndarray):
                        self.tensorboard.add_scalars(
                            f'fold_{self.fold_idx}/{k}_{subset}',
                            {f'class{i}': e for i, e in enumerate(v.ravel().tolist())},
                            global_step=epoch_idx)
                    elif isinstance(v, (str, int, float)):
                        self.tensorboard.add_scalar(
                            f'fold_{self.fold_idx}/{k}_{subset}',
                            float(v),
                            global_step=epoch_idx)
                    else:
                        logger.warning(f'{k} is of unsupported type: {type(v)}')
            for name, optim in self.optimizers.items():
                for param_group in optim.param_groups:
                    self.tensorboard.add_scalar(
                        f'fold_{self.fold_idx}/learning_rate/{name}',
                        param_group['lr'],
                        global_step=epoch_idx)

            # Save the model
            loss_curr = metrics_val['datasetw']['loss_total']
            if loss_curr < loss_best:
                loss_best = loss_curr
                epoch_idx_best = epoch_idx
                metrics_train_best = metrics_train
                metrics_val_best = metrics_val
                fnames_train_best = fnames_train
                fnames_val_best = fnames_val

                self.handlers_ckpt['segm'].save_new_ckpt(
                    model=self.models['segm'],
                    model_name=self.config['model_segm'],
                    fold_idx=self.fold_idx,
                    epoch_idx=epoch_idx)
                self.handlers_ckpt['discr'].save_new_ckpt(
                    model=self.models['discr'],
                    model_name=self.config['model_discr'],
                    fold_idx=self.fold_idx,
                    epoch_idx=epoch_idx)

        msg = (f'Finished fold {self.fold_idx} '
               f'with the best loss {loss_best:.5f} '
               f'on epoch {epoch_idx_best}, '
               f'weights: ({self.paths_weights_fold})')
        logger.info(msg)
        return (metrics_train_best, fnames_train_best,
                metrics_val_best, fnames_val_best)


@click.command()
@click.option('--path_data_root', default='../../data')
@click.option('--path_experiment_root', default='../../results/temporary')
@click.option('--model_segm', default='unet_lext')
@click.option('--center_depth', default=1, type=int)
@click.option('--model_discr', default='discriminator_a')
@click.option('--pretrained', is_flag=True)
@click.option('--path_pretrained_segm', type=str, help='Path to .pth file')
@click.option('--restore_weights', is_flag=True)
@click.option('--input_channels', default=1, type=int)
@click.option('--output_channels', default=1, type=int)
@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
@click.option('--sample_mode', default='x_y', type=str)
@click.option('--loss_segm', default='multi_ce_loss')
@click.option('--lr_segm', default=0.0001, type=float)
@click.option('--lr_discr', default=0.0001, type=float)
@click.option('--wd_segm', default=5e-5, type=float)
@click.option('--wd_discr', default=5e-5, type=float)
@click.option('--optimizer_segm', default='adam')
@click.option('--optimizer_discr', default='adam')
@click.option('--batch_size', default=64, type=int)
@click.option('--epoch_size', default=1.0, type=float)
@click.option('--epoch_num', default=2, type=int)
@click.option('--fold_num', default=5, type=int)
@click.option('--fold_idx', default=-1, type=int)
@click.option('--fold_idx_ignore', multiple=True, type=int)
@click.option('--num_workers', default=1, type=int)
@click.option('--seed_trainval_test', default=0, type=int)
@click.option('--with_mixup', is_flag=True)
@click.option('--mixup_alpha', default=1, type=float)
def main(**config):
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train')
    os.makedirs(config['path_weights'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx'])))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=('oai_imo', 'okoa', 'maknee'),
                                with_folds=True,
                                fold_num=config['fold_num'],
                                seed_trainval_test=config['seed_trainval_test'])

    # Build a list of folds to run on
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx'], ]
    for g in config['fold_idx_ignore']:
        fold_idcs = [i for i in fold_idcs if i != g]

    # Train each fold separately
    fold_scores = dict()

    # Straightforward fold allocation: fold i pairs the i-th trainval split
    # of each dataset
    folds = list(zip(sources['oai_imo']['trainval_folds'],
                     sources['okoa']['trainval_folds'],
                     sources['maknee']['trainval_folds']))

    for fold_idx, idcs_subsets in enumerate(folds):
        if fold_idx not in fold_idcs:
            continue
        logger.info(f'Training fold {fold_idx}')

        (sources['oai_imo']['train_idcs'], sources['oai_imo']['val_idcs']) = idcs_subsets[0]
        (sources['okoa']['train_idcs'], sources['okoa']['val_idcs']) = idcs_subsets[1]
        (sources['maknee']['train_idcs'], sources['maknee']['val_idcs']) = idcs_subsets[2]

        sources['oai_imo']['train_df'] = sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['train_idcs']]
        sources['oai_imo']['val_df'] = sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['val_idcs']]
        sources['okoa']['train_df'] = sources['okoa']['trainval_df'].iloc[sources['okoa']['train_idcs']]
        sources['okoa']['val_df'] = sources['okoa']['trainval_df'].iloc[sources['okoa']['val_idcs']]
        sources['maknee']['train_df'] = sources['maknee']['trainval_df'].iloc[sources['maknee']['train_idcs']]
        sources['maknee']['val_df'] = sources['maknee']['trainval_df'].iloc[sources['maknee']['val_idcs']]

        for n, s in sources.items():
            logger.info('Made {} train-val split, number of samples: {}, {}'
                        .format(n, len(s['train_df']), len(s['val_df'])))

        datasets = defaultdict(dict)

        datasets['oai_imo']['train'] = DatasetOAIiMoSagittal2d(
            df_meta=sources['oai_imo']['train_df'],
            mask_mode=config['mask_mode'],
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                HorizontalFlip(prob=.5),
                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
                OneOf([
                    DualCompose([
                        Scale(ratio_range=(0.7, 0.8), prob=1.),
                        Scale(ratio_range=(1.5, 1.6), prob=1.),
                    ]),
                    NoTransform()
                ]),
                Crop(output_size=(300, 300)),
                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor(),
            ])
        datasets['okoa']['train'] = DatasetOKOASagittal2d(
            df_meta=sources['okoa']['train_df'],
            mask_mode='background_femoral_unitibial',
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                HorizontalFlip(prob=.5),
                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
                OneOf([
                    DualCompose([
                        Scale(ratio_range=(0.7, 0.8), prob=1.),
                        Scale(ratio_range=(1.5, 1.6), prob=1.),
                    ]),
                    NoTransform()
                ]),
                Crop(output_size=(300, 300)),
                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor(),
            ])
        datasets['maknee']['train'] = DatasetMAKNEESagittal2d(
            df_meta=sources['maknee']['train_df'],
            mask_mode='',
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                HorizontalFlip(prob=.5),
                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
                OneOf([
                    DualCompose([
                        Scale(ratio_range=(0.7, 0.8), prob=1.),
                        Scale(ratio_range=(1.5, 1.6), prob=1.),
                    ]),
                    NoTransform()
                ]),
                Crop(output_size=(300, 300)),
                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor(),
            ])
        datasets['oai_imo']['val'] = DatasetOAIiMoSagittal2d(
            df_meta=sources['oai_imo']['val_df'],
            mask_mode=config['mask_mode'],
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor()
            ])
        datasets['okoa']['val'] = DatasetOKOASagittal2d(
            df_meta=sources['okoa']['val_df'],
            mask_mode='background_femoral_unitibial',
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor()
            ])
        datasets['maknee']['val'] = DatasetMAKNEESagittal2d(
            df_meta=sources['maknee']['val_df'],
            mask_mode='',
            sample_mode=config['sample_mode'],
            transforms=[
                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                CenterCrop(height=300, width=300),
                Normalize(mean=0.252699, std=0.251142),
                ToTensor()
            ])

        loaders = defaultdict(dict)

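        # Each loader takes half of `batch_size`: a training step consumes one
        # source (OAI) batch and one target (MAKNEE) batch, so the combined
        # per-step batch size stays at `batch_size`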
        loaders['oai_imo']['train'] = DataLoader(
            datasets['oai_imo']['train'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=True,
            num_workers=config['num_workers'],
            drop_last=True)
        loaders['oai_imo']['val'] = DataLoader(
            datasets['oai_imo']['val'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=False,
            num_workers=config['num_workers'],
            drop_last=True)
        loaders['okoa']['train'] = DataLoader(
            datasets['okoa']['train'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=True,
            num_workers=config['num_workers'],
            drop_last=True)
        loaders['okoa']['val'] = DataLoader(
            datasets['okoa']['val'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=False,
            num_workers=config['num_workers'],
            drop_last=True)
        loaders['maknee']['train'] = DataLoader(
            datasets['maknee']['train'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=True,
            num_workers=config['num_workers'],
            drop_last=True)
        loaders['maknee']['val'] = DataLoader(
            datasets['maknee']['val'],
            batch_size=int(config['batch_size'] / 2),
            shuffle=False,
            num_workers=config['num_workers'],
            drop_last=True)

        trainer = ModelTrainer(config=config, fold_idx=fold_idx)

        metrics_train, fnames_train, metrics_val, fnames_val = \
            trainer.fit(loaders=loaders)

        fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'],
                                 metrics_val['datasetw']['dice_score_okoa'])
        trainer.tensorboard.close()
    logger.info(f'Fold scores:\n{repr(fold_scores)}')


if __name__ == '__main__':
    main()