# rocaseg/train_uda2.py
# NOTE(review): this file had diff-viewer residue pasted into it
# ("Switch to unified view", revision hashes, a/b file markers); that
# non-Python text has been removed so the module can be imported.
1
import os
2
import logging
3
from collections import defaultdict
4
import gc
5
import click
6
import resource
7
8
import numpy as np
9
import cv2
10
11
import torch
12
import torch.nn.functional as torch_fn
13
from torch import nn
14
from torch.utils.data.dataloader import DataLoader
15
from torch.utils.tensorboard import SummaryWriter
16
from tqdm import tqdm
17
18
from rocaseg.datasets import (DatasetOAIiMoSagittal2d,
19
                              DatasetOKOASagittal2d,
20
                              DatasetMAKNEESagittal2d,
21
                              sources_from_path)
22
from rocaseg.models import dict_models
23
from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm,
24
                                dict_optimizers, CheckpointHandler)
25
from rocaseg.preproc import *
26
from rocaseg.repro import set_ultimate_seed
27
28
29
# Raise the soft open-file limit to 4096 (keeping the existing hard limit) —
# presumably to accommodate many DataLoader worker file handles; TODO confirm.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

# Keep OpenCV single-threaded and off OpenCL so it does not contend with
# the PyTorch data-loading workers.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

# Module-level logger for this training script.
logging.basicConfig()
logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)

# Seed all known RNGs for reproducibility.
set_ultimate_seed()

# Select the compute device once, up front.
maybe_gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
45
46
47
class ModelTrainer:
48
    def __init__(self, config, fold_idx=None):
        """Build the UDA training state for one cross-validation fold.

        Creates per-sub-model weights directories and checkpoint handlers,
        instantiates the segmentation network and the two discriminators
        (output-level and auxiliary/penultimate-level), wraps each in
        ``nn.DataParallel`` and moves it to the selected device, and sets up
        the optimizers, losses, and a TensorBoard writer.

        Parameters:
            config (dict): experiment configuration; keys used here include
                'path_weights', 'path_logs', 'model_segm', 'model_discr_aux',
                'model_discr_out', 'input_channels', 'output_channels',
                'center_depth', 'pretrained', 'path_pretrained_segm',
                'restore_weights', 'lr_segm', 'lr_discr', 'loss_segm'.
            fold_idx (int | None): index of the current fold; used only to
                name the weights/logs subdirectories.
        """
        self.config = config
        self.fold_idx = fold_idx

        # One weights directory + checkpoint handler per sub-model.
        # (Previously this was three hand-copied stanzas; a loop keeps the
        # three sub-models consistent by construction.)
        self.paths_weights_fold = dict()
        self.handlers_ckpt = dict()
        paths_ckpt_sel = dict()
        for part in ('segm', 'discr_out', 'discr_aux'):
            path = os.path.join(config['path_weights'], part,
                                f'fold_{self.fold_idx}')
            os.makedirs(path, exist_ok=True)
            self.paths_weights_fold[part] = path
            self.handlers_ckpt[part] = CheckpointHandler(path)
            # Resume from the most recent checkpoint, if any.
            paths_ckpt_sel[part] = self.handlers_ckpt[part].get_last_ckpt()

        self.path_logs_fold = os.path.join(config['path_logs'],
                                           f'fold_{self.fold_idx}')
        os.makedirs(self.path_logs_fold, exist_ok=True)

        # Initialize and configure the models
        self.models = dict()
        # Segmentation network with an auxiliary (penultimate) output head.
        self.models['segm'] = dict_models[config['model_segm']](
            input_channels=self.config['input_channels'],
            output_channels=self.config['output_channels'],
            center_depth=self.config['center_depth'],
            pretrained=self.config['pretrained'],
            path_pretrained=self.config['path_pretrained_segm'],
            restore_weights=self.config['restore_weights'],
            path_weights=paths_ckpt_sel['segm'],
            with_aux=True)
        # Both discriminators classify softmax maps into source/target,
        # hence input == segm output channels and a single output channel.
        for part in ('discr_aux', 'discr_out'):
            self.models[part] = dict_models[config[f'model_{part}']](
                input_channels=self.config['output_channels'],
                output_channels=1,
                pretrained=self.config['pretrained'],
                restore_weights=self.config['restore_weights'],
                path_weights=paths_ckpt_sel[part])
        for part in ('segm', 'discr_aux', 'discr_out'):
            self.models[part] = nn.DataParallel(self.models[part]).to(maybe_gpu)

        # Configure the training: Adam for every sub-model, shared weight
        # decay; both discriminators share one learning rate.
        lrs = {'segm': self.config['lr_segm'],
               'discr_aux': self.config['lr_discr'],
               'discr_out': self.config['lr_discr']}
        self.optimizers = dict()
        for part, lr in lrs.items():
            self.optimizers[part] = dict_optimizers['adam'](
                self.models[part].parameters(),
                lr=lr,
                weight_decay=5e-5)

        # Epoch -> multiplicative LR decay factor.
        self.lr_update_rule = {25: 0.1}

        self.losses = dict()
        self.losses['segm'] = dict_losses[self.config['loss_segm']](
            num_classes=self.config['output_channels'],
        )
        # 'advers' drives the segmentor to fool the discriminators;
        # 'discr' trains the discriminators themselves. Both are BCE.
        self.losses['advers'] = dict_losses['bce_loss']()
        self.losses['discr'] = dict_losses['bce_loss']()
        for name in self.losses:
            self.losses[name] = self.losses[name].to(maybe_gpu)

        self.tensorboard = SummaryWriter(self.path_logs_fold)
138
139
    def run_one_epoch(self, epoch_idx, loaders):
140
        COEFF_DISCR = 1
141
        COEFF_SEGM_AUX = 0.1
142
        COEFF_SEGM_OUT = 1
143
        COEFF_ADVERS_AUX = 0.0002
144
        COEFF_ADVERS_OUT = 0.001
145
146
        fnames_acc = defaultdict(list)
147
        metrics_acc = dict()
148
        metrics_acc['samplew'] = defaultdict(list)
149
        metrics_acc['batchw'] = defaultdict(list)
150
        metrics_acc['datasetw'] = defaultdict(list)
151
        metrics_acc['datasetw']['cm_oai'] = \
152
            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
153
        metrics_acc['datasetw']['cm_okoa'] = \
154
            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
155
156
        prog_bar_params = {'postfix': {'epoch': epoch_idx}, }
157
158
        is_training = (self.models['segm'].training and
159
                       self.models['discr_aux'].training and
160
                       self.models['discr_out'].training)
161
        if is_training:
162
            # ------------------------ Training regime ------------------------
163
            loader_oai = loaders['oai_imo']['train']
164
            loader_maknee = loaders['maknee']['train']
165
166
            steps_oai, steps_maknee = len(loader_oai), len(loader_maknee)
167
            steps_total = steps_oai
168
            prog_bar_params.update({'total': steps_total,
169
                                    'desc': f'Train, epoch {epoch_idx}'})
170
171
            loader_oai_iter = iter(loader_oai)
172
            loader_maknee_iter = iter(loader_maknee)
173
174
            loader_oai_iter_old = None
175
            loader_maknee_iter_old = None
176
177
            with tqdm(**prog_bar_params) as prog_bar:
178
                for step_idx in range(steps_total):
179
                    self.optimizers['segm'].zero_grad()
180
                    self.optimizers['discr_aux'].zero_grad()
181
                    self.optimizers['discr_out'].zero_grad()
182
183
                    metrics_acc['batchw']['loss_total'].append(0)
184
185
                    try:
186
                        data_batch_oai = next(loader_oai_iter)
187
                    except StopIteration:
188
                        loader_oai_iter_old = loader_oai_iter
189
                        loader_oai_iter = iter(loader_oai)
190
                        data_batch_oai = next(loader_oai_iter)
191
192
                    try:
193
                        data_batch_maknee = next(loader_maknee_iter)
194
                    except StopIteration:
195
                        loader_maknee_iter_old = loader_maknee_iter
196
                        loader_maknee_iter = iter(loader_maknee)
197
                        data_batch_maknee = next(loader_maknee_iter)
198
199
                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
200
                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
201
                    xs_oai = xs_oai.to(maybe_gpu)
202
                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
203
204
                    xs_maknee, ys_true_maknee = data_batch_maknee['xs'], data_batch_maknee['ys']
205
                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
206
                    xs_maknee = xs_maknee.to(maybe_gpu)
207
208
                    # -------------- Train discriminator network -------------
209
                    # With source
210
                    ys_pred_out_oai, ys_pred_aux_oai = self.models['segm'](xs_oai)
211
                    ys_pred_out_softmax_oai = torch_fn.softmax(ys_pred_out_oai, dim=1)
212
                    ys_pred_aux_softmax_oai = torch_fn.softmax(ys_pred_aux_oai, dim=1)
213
214
                    zs_pred_out_oai = self.models['discr_out'](ys_pred_out_softmax_oai)
215
                    zs_pred_aux_oai = self.models['discr_aux'](ys_pred_aux_softmax_oai)
216
217
                    # Use 0 as a label for the source domain
218
                    # Output representations
219
                    loss_discr_out_0 = self.losses['discr'](
220
                        input=zs_pred_out_oai,
221
                        target=torch.zeros_like(zs_pred_out_oai, device=maybe_gpu))
222
                    loss_discr_out_0 = loss_discr_out_0 / 2 * COEFF_DISCR
223
                    loss_discr_out_0.backward(retain_graph=True)
224
                    metrics_acc['batchw']['loss_discr_out_0'].append(loss_discr_out_0.item())
225
                    metrics_acc['batchw']['loss_total'][-1] += \
226
                        metrics_acc['batchw']['loss_discr_out_0'][-1]
227
228
                    # Penultimate representations
229
                    loss_discr_aux_0 = self.losses['discr'](
230
                        input=zs_pred_aux_oai,
231
                        target=torch.zeros_like(zs_pred_aux_oai, device=maybe_gpu))
232
                    loss_discr_aux_0 = loss_discr_aux_0 / 2 * COEFF_DISCR
233
                    loss_discr_aux_0.backward(retain_graph=True)
234
                    metrics_acc['batchw']['loss_discr_aux_0'].append(
235
                        loss_discr_aux_0.item())
236
                    metrics_acc['batchw']['loss_total'][-1] += \
237
                        metrics_acc['batchw']['loss_discr_aux_0'][-1]
238
239
                    # With target
240
                    self.models['segm'] = self.models['segm'].eval()
241
                    ys_pred_out_maknee, ys_pred_aux_maknee = self.models['segm'](xs_maknee)
242
                    self.models['segm'] = self.models['segm'].train()
243
244
                    ys_pred_out_softmax_maknee = torch_fn.softmax(ys_pred_out_maknee, dim=1)
245
                    ys_pred_aux_softmax_maknee = torch_fn.softmax(ys_pred_aux_maknee, dim=1)
246
                    zs_pred_out_maknee = self.models['discr_out'](ys_pred_out_softmax_maknee)
247
                    zs_pred_aux_maknee = self.models['discr_aux'](ys_pred_aux_softmax_maknee)
248
249
                    # Use 1 as a label for the target domain
250
                    # Output representations
251
                    loss_discr_out_1 = self.losses['discr'](
252
                        input=zs_pred_out_maknee,
253
                        target=torch.ones_like(zs_pred_out_maknee, device=maybe_gpu))
254
                    loss_discr_out_1 = loss_discr_out_1 / 2 * COEFF_DISCR
255
                    loss_discr_out_1.backward(retain_graph=True)
256
                    metrics_acc['batchw']['loss_discr_out_1'].append(loss_discr_out_1.item())
257
                    metrics_acc['batchw']['loss_total'][-1] += \
258
                        metrics_acc['batchw']['loss_discr_out_1'][-1]
259
260
                    # Penultimate representations
261
                    loss_discr_aux_1 = self.losses['discr'](
262
                        input=zs_pred_aux_maknee,
263
                        target=torch.ones_like(zs_pred_aux_maknee, device=maybe_gpu))
264
                    loss_discr_aux_1 = loss_discr_aux_1 / 2 * COEFF_DISCR
265
                    loss_discr_aux_1.backward()
266
                    metrics_acc['batchw']['loss_discr_aux_1'].append(
267
                        loss_discr_aux_1.item())
268
                    metrics_acc['batchw']['loss_total'][-1] += \
269
                        metrics_acc['batchw']['loss_discr_aux_1'][-1]
270
271
                    self.models['segm'].zero_grad()
272
                    self.optimizers['discr_aux'].step()
273
                    self.optimizers['discr_out'].step()
274
                    self.models['discr_aux'].zero_grad()
275
                    self.models['discr_out'].zero_grad()
276
277
                    # ---------------- Train segmentation network ------------
278
                    # With source
279
                    ys_pred_out_oai, ys_pred_aux_oai = self.models['segm'](xs_oai)
280
281
                    # Output representations
282
                    loss_segm_out = self.losses['segm'](input_=ys_pred_out_oai,
283
                                                        target=ys_true_arg_oai)
284
                    loss_segm_out.backward(retain_graph=True)
285
                    loss_segm_out = loss_segm_out * COEFF_SEGM_OUT
286
                    metrics_acc['batchw']['loss_segm_out'].append(loss_segm_out.item())
287
                    metrics_acc['batchw']['loss_total'][-1] += \
288
                        metrics_acc['batchw']['loss_segm_out'][-1]
289
290
                    # Penultimate representations
291
                    loss_segm_aux = self.losses['segm'](input_=ys_pred_aux_oai,
292
                                                        target=ys_true_arg_oai)
293
                    loss_segm_aux.backward(retain_graph=True)
294
                    loss_segm_aux = loss_segm_aux * COEFF_SEGM_AUX
295
                    metrics_acc['batchw']['loss_segm_aux'].append(loss_segm_aux.item())
296
                    metrics_acc['batchw']['loss_total'][-1] += \
297
                        metrics_acc['batchw']['loss_segm_aux'][-1]
298
299
                    # With target
300
                    self.models['segm'] = self.models['segm'].eval()
301
                    ys_pred_out_maknee, ys_pred_aux_maknee = self.models['segm'](xs_maknee)
302
                    self.models['segm'] = self.models['segm'].train()
303
304
                    ys_pred_out_softmax_maknee = torch_fn.softmax(ys_pred_out_maknee, dim=1)
305
                    ys_pred_aux_softmax_maknee = torch_fn.softmax(ys_pred_aux_maknee, dim=1)
306
                    zs_pred_out_maknee = self.models['discr_out'](ys_pred_out_softmax_maknee)
307
                    zs_pred_aux_maknee = self.models['discr_aux'](ys_pred_aux_softmax_maknee)
308
309
                    # Use 0 as a label for the source domain
310
                    # Output representations
311
                    loss_advers_out = self.losses['advers'](
312
                        input=zs_pred_out_maknee,
313
                        target=torch.zeros_like(zs_pred_out_maknee, device=maybe_gpu))
314
                    loss_advers_out = loss_advers_out * COEFF_ADVERS_OUT
315
                    loss_advers_out.backward(retain_graph=True)
316
                    metrics_acc['batchw']['loss_advers_out'].append(loss_advers_out.item())
317
                    metrics_acc['batchw']['loss_total'][-1] += \
318
                        metrics_acc['batchw']['loss_advers_out'][-1]
319
320
                    # Penultimate representations
321
                    loss_advers_aux = self.losses['advers'](
322
                        input=zs_pred_aux_maknee,
323
                        target=torch.zeros_like(zs_pred_aux_maknee, device=maybe_gpu))
324
                    loss_advers_aux = loss_advers_aux * COEFF_ADVERS_AUX
325
                    loss_advers_aux.backward()
326
                    metrics_acc['batchw']['loss_advers_aux'].append(
327
                        loss_advers_aux.item())
328
                    metrics_acc['batchw']['loss_total'][-1] += \
329
                        metrics_acc['batchw']['loss_advers_aux'][-1]
330
331
                    self.models['discr_out'].zero_grad()
332
                    self.models['discr_aux'].zero_grad()
333
                    self.optimizers['segm'].step()
334
                    self.models['segm'].zero_grad()
335
336
                    if step_idx % 10 == 0:
337
                        self.tensorboard.add_scalars(
338
                            f'fold_{self.fold_idx}/losses_train',
339
                            {'discr_0_aux_batchw':
340
                                 metrics_acc['batchw']['loss_discr_aux_0'][-1],
341
                             'discr_0_out_batchw':
342
                                 metrics_acc['batchw']['loss_discr_out_0'][-1],
343
                             'discr_1_aux_batchw':
344
                                 metrics_acc['batchw']['loss_discr_aux_1'][-1],
345
                             'discr_1_out_batchw':
346
                                 metrics_acc['batchw']['loss_discr_out_1'][-1],
347
                             'discr_sum_aux_batchw':
348
                                 (metrics_acc['batchw']['loss_discr_aux_0'][-1] +
349
                                  metrics_acc['batchw']['loss_discr_aux_1'][-1]),
350
                             'discr_sum_out_batchw':
351
                                 (metrics_acc['batchw']['loss_discr_out_0'][-1] +
352
                                  metrics_acc['batchw']['loss_discr_out_1'][-1]),
353
                             'discr_sum_all_batchw':
354
                                 (metrics_acc['batchw']['loss_discr_aux_0'][-1] +
355
                                  metrics_acc['batchw']['loss_discr_out_0'][-1] +
356
                                  metrics_acc['batchw']['loss_discr_aux_1'][-1] +
357
                                  metrics_acc['batchw']['loss_discr_out_1'][-1]),
358
                             'segm_aux_batchw': metrics_acc['batchw']['loss_segm_aux'][-1],
359
                             'segm_out_batchw': metrics_acc['batchw']['loss_segm_out'][-1],
360
                             'advers_aux_batchw': metrics_acc['batchw']['loss_advers_aux'][-1],
361
                             'advers_out_batchw': metrics_acc['batchw']['loss_advers_out'][-1],
362
                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
363
                             }, global_step=(epoch_idx * steps_total + step_idx))
364
365
                    prog_bar.update(1)
366
367
            del [loader_oai_iter_old, loader_maknee_iter_old]
368
            gc.collect()
369
        else:
370
            # ----------------------- Validation regime -----------------------
371
            loader_oai = loaders['oai_imo']['val']
372
            loader_okoa = loaders['okoa']['val']
373
            loader_maknee = loaders['maknee']['val']
374
375
            steps_oai, steps_okoa, steps_maknee = len(loader_oai), len(loader_okoa), len(loader_maknee)
376
            steps_total = steps_oai
377
            prog_bar_params.update({'total': steps_total,
378
                                    'desc': f'Validate, epoch {epoch_idx}'})
379
380
            loader_oai_iter = iter(loader_oai)
381
            loader_okoa_iter = iter(loader_okoa)
382
            loader_maknee_iter = iter(loader_maknee)
383
384
            loader_oai_iter_old = None
385
            loader_okoa_iter_old = None
386
            loader_maknee_iter_old = None
387
388
            with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar:
389
                for step_idx in range(steps_total):
390
                    metrics_acc['batchw']['loss_total'].append(0)
391
392
                    try:
393
                        data_batch_oai = next(loader_oai_iter)
394
                    except StopIteration:
395
                        loader_oai_iter_old = loader_oai_iter
396
                        loader_oai_iter = iter(loader_oai)
397
                        data_batch_oai = next(loader_oai_iter)
398
399
                    try:
400
                        data_batch_okoa = next(loader_okoa_iter)
401
                    except StopIteration:
402
                        loader_okoa_iter_old = loader_okoa_iter
403
                        loader_okoa_iter = iter(loader_okoa)
404
                        data_batch_okoa = next(loader_okoa_iter)
405
406
                    try:
407
                        data_batch_maknee = next(loader_maknee_iter)
408
                    except StopIteration:
409
                        loader_maknee_iter_old = loader_maknee_iter
410
                        loader_maknee_iter = iter(loader_maknee)
411
                        data_batch_maknee = next(loader_maknee_iter)
412
413
                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
414
                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
415
                    xs_oai = xs_oai.to(maybe_gpu)
416
                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
417
418
                    xs_maknee, ys_true_maknee = data_batch_maknee['xs'], data_batch_maknee['ys']
419
                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
420
                    xs_maknee = xs_maknee.to(maybe_gpu)
421
422
                    # -------------- Validate discriminator network -------------
423
                    # With source
424
                    ys_pred_out_oai, ys_pred_aux_oai = self.models['segm'](xs_oai)
425
                    ys_pred_out_softmax_oai = torch_fn.softmax(ys_pred_out_oai, dim=1)
426
                    ys_pred_aux_softmax_oai = torch_fn.softmax(ys_pred_aux_oai, dim=1)
427
428
                    zs_pred_out_oai = self.models['discr_out'](ys_pred_out_softmax_oai)
429
                    zs_pred_aux_oai = self.models['discr_aux'](ys_pred_aux_softmax_oai)
430
431
                    # Use 0 as a label for the source domain
432
                    # Output representations
433
                    loss_discr_out_0 = self.losses['discr'](
434
                        input=zs_pred_out_oai,
435
                        target=torch.zeros_like(zs_pred_out_oai, device=maybe_gpu))
436
                    loss_discr_out_0 = loss_discr_out_0 / 2 * COEFF_DISCR
437
                    metrics_acc['batchw']['loss_discr_out_0'].append(
438
                        loss_discr_out_0.item())
439
                    metrics_acc['batchw']['loss_total'][-1] += \
440
                        metrics_acc['batchw']['loss_discr_out_0'][-1]
441
442
                    # Penultimate representations
443
                    loss_discr_aux_0 = self.losses['discr'](
444
                        input=zs_pred_aux_oai,
445
                        target=torch.zeros_like(zs_pred_aux_oai, device=maybe_gpu))
446
                    loss_discr_aux_0 = loss_discr_aux_0 / 2 * COEFF_DISCR
447
                    metrics_acc['batchw']['loss_discr_aux_0'].append(
448
                        loss_discr_aux_0.item())
449
                    metrics_acc['batchw']['loss_total'][-1] += \
450
                        metrics_acc['batchw']['loss_discr_aux_0'][-1]
451
452
                    # With target
453
                    ys_pred_out_maknee, ys_pred_aux_maknee = self.models['segm'](xs_maknee)
454
455
                    ys_pred_out_softmax_maknee = torch_fn.softmax(ys_pred_out_maknee,
456
                                                                  dim=1)
457
                    ys_pred_aux_softmax_maknee = torch_fn.softmax(ys_pred_aux_maknee,
458
                                                                  dim=1)
459
                    zs_pred_out_maknee = self.models['discr_out'](ys_pred_out_softmax_maknee)
460
                    zs_pred_aux_maknee = self.models['discr_aux'](ys_pred_aux_softmax_maknee)
461
462
                    # Use 1 as a label for the target domain
463
                    # Output representations
464
                    loss_discr_out_1 = self.losses['discr'](
465
                        input=zs_pred_out_maknee,
466
                        target=torch.ones_like(zs_pred_out_maknee, device=maybe_gpu))
467
                    loss_discr_out_1 = loss_discr_out_1 / 2 * COEFF_DISCR
468
                    metrics_acc['batchw']['loss_discr_out_1'].append(
469
                        loss_discr_out_1.item())
470
                    metrics_acc['batchw']['loss_total'][-1] += \
471
                        metrics_acc['batchw']['loss_discr_out_1'][-1]
472
473
                    # Penultimate representations
474
                    loss_discr_aux_1 = self.losses['discr'](
475
                        input=zs_pred_aux_maknee,
476
                        target=torch.ones_like(zs_pred_aux_maknee, device=maybe_gpu))
477
                    loss_discr_aux_1 = loss_discr_aux_1 / 2 * COEFF_DISCR
478
                    metrics_acc['batchw']['loss_discr_aux_1'].append(
479
                        loss_discr_aux_1.item())
480
                    metrics_acc['batchw']['loss_total'][-1] += \
481
                        metrics_acc['batchw']['loss_discr_aux_1'][-1]
482
483
                    # ---------------- Validate segmentation network ------------
484
                    # With source
485
                    ys_pred_out_oai, ys_pred_aux_oai = self.models['segm'](xs_oai)
486
487
                    # Output representations
488
                    loss_segm_out = self.losses['segm'](input_=ys_pred_out_oai,
489
                                                        target=ys_true_arg_oai)
490
                    loss_segm_out = loss_segm_out * COEFF_SEGM_OUT
491
                    metrics_acc['batchw']['loss_segm_out'].append(loss_segm_out.item())
492
                    metrics_acc['batchw']['loss_total'][-1] += \
493
                        metrics_acc['batchw']['loss_segm_out'][-1]
494
495
                    # Penultimate representations
496
                    loss_segm_aux = self.losses['segm'](input_=ys_pred_aux_oai,
497
                                                        target=ys_true_arg_oai)
498
                    loss_segm_aux = loss_segm_aux * COEFF_SEGM_AUX
499
                    metrics_acc['batchw']['loss_segm_aux'].append(loss_segm_aux.item())
500
                    metrics_acc['batchw']['loss_total'][-1] += \
501
                        metrics_acc['batchw']['loss_segm_aux'][-1]
502
503
                    # With target
504
                    ys_pred_out_maknee, ys_pred_aux_maknee = self.models['segm'](xs_maknee)
505
506
                    ys_pred_out_softmax_maknee = torch_fn.softmax(ys_pred_out_maknee, dim=1)
507
                    ys_pred_aux_softmax_maknee = torch_fn.softmax(ys_pred_aux_maknee, dim=1)
508
                    zs_pred_out_maknee = self.models['discr_out'](ys_pred_out_softmax_maknee)
509
                    zs_pred_aux_maknee = self.models['discr_aux'](ys_pred_aux_softmax_maknee)
510
511
                    # Use 0 as a label for the source domain
512
                    # Output representations
513
                    loss_advers_out = self.losses['advers'](
514
                        input=zs_pred_out_maknee,
515
                        target=torch.zeros_like(zs_pred_out_maknee, device=maybe_gpu))
516
                    loss_advers_out = loss_advers_out * COEFF_ADVERS_OUT
517
                    metrics_acc['batchw']['loss_advers_out'].append(
518
                        loss_advers_out.item())
519
                    metrics_acc['batchw']['loss_total'][-1] += \
520
                        metrics_acc['batchw']['loss_advers_out'][-1]
521
522
                    # Penultimate representations
523
                    loss_advers_aux = self.losses['advers'](
524
                        input=zs_pred_aux_maknee,
525
                        target=torch.zeros_like(zs_pred_aux_maknee, device=maybe_gpu))
526
                    loss_advers_aux = loss_advers_aux * COEFF_ADVERS_AUX
527
                    metrics_acc['batchw']['loss_advers_aux'].append(
528
                        loss_advers_aux.item())
529
                    metrics_acc['batchw']['loss_total'][-1] += \
530
                        metrics_acc['batchw']['loss_advers_aux'][-1]
531
532
                    if step_idx % 10 == 0:
533
                        self.tensorboard.add_scalars(
534
                            f'fold_{self.fold_idx}/losses_val',
535
                            {'discr_0_aux_batchw':
536
                                 metrics_acc['batchw']['loss_discr_aux_0'][-1],
537
                             'discr_0_out_batchw':
538
                                 metrics_acc['batchw']['loss_discr_out_0'][-1],
539
                             'discr_1_aux_batchw':
540
                                 metrics_acc['batchw']['loss_discr_aux_1'][-1],
541
                             'discr_1_out_batchw':
542
                                 metrics_acc['batchw']['loss_discr_out_1'][-1],
543
                             'discr_sum_aux_batchw':
544
                                 (metrics_acc['batchw']['loss_discr_aux_0'][-1] +
545
                                  metrics_acc['batchw']['loss_discr_aux_1'][-1]),
546
                             'discr_sum_out_batchw':
547
                                 (metrics_acc['batchw']['loss_discr_out_0'][-1] +
548
                                  metrics_acc['batchw']['loss_discr_out_1'][-1]),
549
                             'discr_sum_all_batchw':
550
                                 (metrics_acc['batchw']['loss_discr_aux_0'][-1] +
551
                                  metrics_acc['batchw']['loss_discr_out_0'][-1] +
552
                                  metrics_acc['batchw']['loss_discr_aux_1'][-1] +
553
                                  metrics_acc['batchw']['loss_discr_out_1'][-1]),
554
                             'segm_aux_batchw': metrics_acc['batchw']['loss_segm_aux'][-1],
555
                             'segm_out_batchw': metrics_acc['batchw']['loss_segm_out'][-1],
556
                             'advers_aux_batchw': metrics_acc['batchw']['loss_advers_aux'][-1],
557
                             'advers_out_batchw': metrics_acc['batchw']['loss_advers_out'][-1],
558
                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
559
                             }, global_step=(epoch_idx * steps_total + step_idx))
560
561
                    # ------------------ Calculate metrics -------------------
562
                    ys_pred_arg_np_oai = torch.argmax(ys_pred_out_softmax_oai, 1).to('cpu').numpy()
563
                    ys_true_arg_np_oai = ys_true_arg_oai.to('cpu').numpy()
564
565
                    metrics_acc['datasetw']['cm_oai'] += confusion_matrix(
566
                        ys_pred_arg_np_oai, ys_true_arg_np_oai,
567
                        self.config['output_channels'])
568
569
                    # Don't consider repeating entries for the metrics calculation
570
                    if step_idx < steps_okoa:
571
                        xs_okoa, ys_true_okoa = data_batch_okoa['xs'], data_batch_okoa['ys']
572
                        fnames_acc['okoa'].extend(data_batch_okoa['path_image'])
573
                        xs_okoa = xs_okoa.to(maybe_gpu)
574
575
                        ys_pred_okoa, _ = self.models['segm'](xs_okoa)
576
577
                        ys_true_arg_okoa = torch.argmax(ys_true_okoa.long().to(maybe_gpu), dim=1)
578
                        ys_pred_softmax_okoa = torch_fn.softmax(ys_pred_okoa, dim=1)
579
580
                        ys_pred_arg_np_okoa = torch.argmax(ys_pred_softmax_okoa, 1).to('cpu').numpy()
581
                        ys_true_arg_np_okoa = ys_true_arg_okoa.to('cpu').numpy()
582
583
                        metrics_acc['datasetw']['cm_okoa'] += confusion_matrix(
584
                            ys_pred_arg_np_okoa, ys_true_arg_np_okoa,
585
                            self.config['output_channels'])
586
587
                    prog_bar.update(1)
588
589
            del [loader_oai_iter_old, loader_okoa_iter_old, loader_maknee_iter_old]
590
            gc.collect()
591
592
        for k, v in metrics_acc['samplew'].items():
593
            metrics_acc['samplew'][k] = np.asarray(v)
594
        metrics_acc['datasetw']['dice_score_oai'] = np.asarray(
595
            dice_score_from_cm(metrics_acc['datasetw']['cm_oai']))
596
        metrics_acc['datasetw']['dice_score_okoa'] = np.asarray(
597
            dice_score_from_cm(metrics_acc['datasetw']['cm_okoa']))
598
        return metrics_acc, fnames_acc
599
600
    def fit(self, loaders):
        """Run the full training loop over ``self.config['epoch_num']`` epochs.

        For each epoch: run a training pass and a validation pass via
        ``self.run_one_epoch``, reduce the accumulated batch-wise losses to
        dataset-wise means, apply the scheduled learning-rate multipliers,
        log to console and TensorBoard, and checkpoint all three models
        whenever the validation ``loss_total`` improves.

        Args:
            loaders: Mapping of dataset name to ``{'train': ..., 'val': ...}``
                DataLoaders, passed through to ``run_one_epoch``.

        Returns:
            Tuple ``(metrics_train_best, fnames_train_best,
            metrics_val_best, fnames_val_best)`` captured at the epoch with
            the lowest validation loss.
        """
        def _reduce_batchwise(metrics):
            # Fold each batch-wise `loss*` series into a dataset-wise mean,
            # in-place. Non-loss entries are not reduced here, only reported.
            for k, v in metrics['batchw'].items():
                if k.startswith('loss'):
                    metrics['datasetw'][k] = np.mean(np.asarray(v))
                else:
                    logger.warning(f'Non-processed batch-wise entry: {k}')

        epoch_idx_best = -1
        loss_best = float('inf')
        metrics_train_best = dict()
        fnames_train_best = []
        metrics_val_best = dict()
        fnames_val_best = []

        for epoch_idx in range(self.config['epoch_num']):
            # Training pass: run_one_epoch presumably branches on the
            # modules' train/eval mode set here -- TODO confirm upstream.
            self.models = {n: m.train() for n, m in self.models.items()}
            metrics_train, fnames_train = \
                self.run_one_epoch(epoch_idx, loaders)
            _reduce_batchwise(metrics_train)

            # Validation pass
            self.models = {n: m.eval() for n, m in self.models.items()}
            metrics_val, fnames_val = \
                self.run_one_epoch(epoch_idx, loaders)
            _reduce_batchwise(metrics_val)

            # Learning rate update: at scheduled epoch `s`, scale every
            # optimizer's lr by multiplier `m`
            for s, m in self.lr_update_rule.items():
                if epoch_idx == s:
                    for name, optim in self.optimizers.items():
                        for param_group in optim.param_groups:
                            param_group['lr'] *= m

            # Add console logging
            logger.info(f'Epoch: {epoch_idx}')
            for subset, metrics in (('train', metrics_train),
                                    ('val', metrics_val)):
                logger.info(f'{subset} metrics:')
                for k, v in metrics['datasetw'].items():
                    logger.info(f'{k}: \n{v}')

            # Add TensorBoard logging
            for subset, metrics in (('train', metrics_train),
                                    ('val', metrics_val)):
                # Log only dataset-reduced metrics
                for k, v in metrics['datasetw'].items():
                    if isinstance(v, np.ndarray):
                        # Per-class vector metrics (e.g. Dice scores)
                        self.tensorboard.add_scalars(
                            f'fold_{self.fold_idx}/{k}_{subset}',
                            {f'class{i}': e for i, e in enumerate(v.ravel().tolist())},
                            global_step=epoch_idx)
                    elif isinstance(v, (str, int, float)):
                        self.tensorboard.add_scalar(
                            f'fold_{self.fold_idx}/{k}_{subset}',
                            float(v),
                            global_step=epoch_idx)
                    else:
                        logger.warning(f'{k} is of unsupported dtype {v}')
            for name, optim in self.optimizers.items():
                for param_group in optim.param_groups:
                    self.tensorboard.add_scalar(
                        f'fold_{self.fold_idx}/learning_rate/{name}',
                        param_group['lr'],
                        global_step=epoch_idx)

            # Save the model
            loss_curr = metrics_val['datasetw']['loss_total']
            # Alternative validation criterion:
            # loss_curr = metrics_val['datasetw']['loss_segm'] + metrics_val['datasetw']['loss_advers']
            if loss_curr < loss_best:
                loss_best = loss_curr
                epoch_idx_best = epoch_idx
                metrics_train_best = metrics_train
                metrics_val_best = metrics_val
                fnames_train_best = fnames_train
                fnames_val_best = fnames_val

                # Checkpoint all three models; order matches the original
                # (segm first, then the two discriminators)
                for ckpt_name in ('segm', 'discr_aux', 'discr_out'):
                    self.handlers_ckpt[ckpt_name].save_new_ckpt(
                        model=self.models[ckpt_name],
                        model_name=self.config[f'model_{ckpt_name}'],
                        fold_idx=self.fold_idx,
                        epoch_idx=epoch_idx)

        msg = (f'Finished fold {self.fold_idx} '
               f'with the best loss {loss_best:.5f} '
               f'on epoch {epoch_idx_best}, '
               f'weights: ({self.paths_weights_fold})')
        logger.info(msg)
        return (metrics_train_best, fnames_train_best,
                metrics_val_best, fnames_val_best)
@click.command()
@click.option('--path_data_root', default='../../data')
@click.option('--path_experiment_root', default='../../results/temporary')
@click.option('--model_segm', default='unet_lext')
@click.option('--center_depth', default=1, type=int)
@click.option('--model_discr_out', default='discriminator_a')
@click.option('--model_discr_aux', default='discriminator_a')
@click.option('--pretrained', is_flag=True)
@click.option('--path_pretrained_segm', type=str, help='Path to .pth file')
@click.option('--restore_weights', is_flag=True)
@click.option('--input_channels', default=1, type=int)
@click.option('--output_channels', default=1, type=int)
@click.option('--mask_mode', default='', type=str)
@click.option('--sample_mode', default='x_y', type=str)
@click.option('--loss_segm', default='multi_ce_loss')
@click.option('--lr_segm', default=0.0001, type=float)
@click.option('--lr_discr', default=0.0001, type=float)
@click.option('--optimizer_segm', default='adam')
@click.option('--optimizer_discr', default='adam')
@click.option('--batch_size', default=64, type=int)
@click.option('--epoch_size', default=1.0, type=float)
@click.option('--epoch_num', default=2, type=int)
@click.option('--fold_num', default=5, type=int)
@click.option('--fold_idx', default=-1, type=int)
@click.option('--fold_idx_ignore', multiple=True, type=int)
@click.option('--num_workers', default=1, type=int)
@click.option('--seed_trainval_test', default=0, type=int)
def main(**config):
    """Entry point: set up paths/logging, build per-fold datasets and
    loaders for the OAI-iMo, OKOA, and MAKNEE sources, and train one
    ``ModelTrainer`` per selected fold.
    """
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train')
    os.makedirs(config['path_weights'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx'])))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=('oai_imo', 'okoa', 'maknee'),
                                with_folds=True,
                                fold_num=config['fold_num'],
                                seed_trainval_test=config['seed_trainval_test'])

    # Build a list of folds to run on
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx'], ]
    # Drop explicitly ignored folds
    ignored = set(config['fold_idx_ignore'])
    fold_idcs = [i for i in fold_idcs if i not in ignored]

    # Train each fold separately
    fold_scores = dict()

    # Use straightforward fold allocation strategy
    names = ('oai_imo', 'okoa', 'maknee')
    folds = list(zip(*(sources[n]['trainval_folds'] for n in names)))

    def _train_transforms():
        # Augmentation pipeline shared by every training dataset.
        # Built fresh per dataset so no transform state is shared.
        return [
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            CenterCrop(height=300, width=300),
            HorizontalFlip(prob=.5),
            GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
            OneOf([
                DualCompose([
                    Scale(ratio_range=(0.7, 0.8), prob=1.),
                    Scale(ratio_range=(1.5, 1.6), prob=1.),
                ]),
                NoTransform()
            ]),
            Crop(output_size=(300, 300)),
            BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
            Normalize(mean=0.252699, std=0.251142),
            ToTensor(),
        ]

    def _val_transforms():
        # Deterministic pipeline shared by every validation dataset
        return [
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            CenterCrop(height=300, width=300),
            Normalize(mean=0.252699, std=0.251142),
            ToTensor()
        ]

    def _make_loader(dataset, shuffle):
        # Half the nominal batch size: each training step consumes a
        # source-domain and a target-domain half-batch
        return DataLoader(dataset,
                          batch_size=config['batch_size'] // 2,
                          shuffle=shuffle,
                          num_workers=config['num_workers'],
                          drop_last=True)

    # Dataset class and mask mode per source. OKOA uses a fixed reduced
    # mask scheme; MAKNEE is unannotated (empty mask mode).
    dataset_specs = {
        'oai_imo': (DatasetOAIiMoSagittal2d, config['mask_mode']),
        'okoa': (DatasetOKOASagittal2d, 'background_femoral_unitibial'),
        'maknee': (DatasetMAKNEESagittal2d, ''),
    }

    for fold_idx, idcs_subsets in enumerate(folds):
        if fold_idx not in fold_idcs:
            continue
        logger.info(f'Training fold {fold_idx}')

        # Materialize the train/val splits for each source
        for name, (train_idcs, val_idcs) in zip(names, idcs_subsets):
            src = sources[name]
            src['train_idcs'], src['val_idcs'] = train_idcs, val_idcs
            src['train_df'] = src['trainval_df'].iloc[train_idcs]
            src['val_df'] = src['trainval_df'].iloc[val_idcs]

        for n, s in sources.items():
            logger.info('Made {} train-val split, number of samples: {}, {}'
                        .format(n, len(s['train_df']), len(s['val_df'])))

        datasets = defaultdict(dict)
        loaders = defaultdict(dict)

        for name in names:
            dataset_cls, mask_mode = dataset_specs[name]
            datasets[name]['train'] = dataset_cls(
                df_meta=sources[name]['train_df'],
                mask_mode=mask_mode,
                sample_mode=config['sample_mode'],
                transforms=_train_transforms())
            datasets[name]['val'] = dataset_cls(
                df_meta=sources[name]['val_df'],
                mask_mode=mask_mode,
                sample_mode=config['sample_mode'],
                transforms=_val_transforms())
            loaders[name]['train'] = _make_loader(datasets[name]['train'],
                                                  shuffle=True)
            loaders[name]['val'] = _make_loader(datasets[name]['val'],
                                                shuffle=False)

        trainer = ModelTrainer(config=config, fold_idx=fold_idx)

        tmp = trainer.fit(loaders=loaders)
        metrics_train, fnames_train, metrics_val, fnames_val = tmp

        fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'],
                                 metrics_val['datasetw']['dice_score_okoa'])
        trainer.tensorboard.close()
    logger.info(f'Fold scores:\n{repr(fold_scores)}')
if __name__ == '__main__':
    # Script entry point; click parses CLI options into `config`
    main()