--- a
+++ b/rocaseg/train_uda1.py
@@ -0,0 +1,795 @@
+import os
+import logging
+from collections import defaultdict
+import gc
+import click
+import resource
+
+import numpy as np
+import cv2
+
+import torch
+import torch.nn.functional as torch_fn
+from torch import nn
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from rocaseg.datasets import (DatasetOAIiMoSagittal2d,
+                              DatasetOKOASagittal2d,
+                              DatasetMAKNEESagittal2d,
+                              sources_from_path)
+from rocaseg.models import dict_models
+from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm,
+                                dict_optimizers, CheckpointHandler)
+from rocaseg.preproc import *
+from rocaseg.repro import set_ultimate_seed
+from rocaseg.components.mixup import mixup_criterion, mixup_data
+
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+
+logging.basicConfig()
+logger = logging.getLogger('train')
+logger.setLevel(logging.DEBUG)
+
+set_ultimate_seed()
+
+if torch.cuda.is_available():
+    maybe_gpu = 'cuda'
+else:
+    maybe_gpu = 'cpu'
+
+
+class ModelTrainer:
+    def __init__(self, config, fold_idx=None):
+        self.config = config
+        self.fold_idx = fold_idx
+
+        self.paths_weights_fold = dict()
+        self.paths_weights_fold['segm'] = \
+            os.path.join(config['path_weights'], 'segm', f'fold_{self.fold_idx}')
+        os.makedirs(self.paths_weights_fold['segm'], exist_ok=True)
+        self.paths_weights_fold['discr'] = \
+            os.path.join(config['path_weights'], 'discr', f'fold_{self.fold_idx}')
+        os.makedirs(self.paths_weights_fold['discr'], exist_ok=True)
+
+        self.path_logs_fold = \
+            os.path.join(config['path_logs'], f'fold_{self.fold_idx}')
+        os.makedirs(self.path_logs_fold, exist_ok=True)
+
+        self.handlers_ckpt = dict()
+        self.handlers_ckpt['segm'] = CheckpointHandler(self.paths_weights_fold['segm'])
+        self.handlers_ckpt['discr'] = CheckpointHandler(self.paths_weights_fold['discr'])
+
+        paths_ckpt_sel = dict()
+        paths_ckpt_sel['segm'] = self.handlers_ckpt['segm'].get_last_ckpt()
+        paths_ckpt_sel['discr'] = self.handlers_ckpt['discr'].get_last_ckpt()
+
+        # Initialize and configure the models
+        self.models = dict()
+        self.models['segm'] = (dict_models[config['model_segm']]
+                               (input_channels=self.config['input_channels'],
+                                output_channels=self.config['output_channels'],
+                                center_depth=self.config['center_depth'],
+                                pretrained=self.config['pretrained'],
+                                path_pretrained=self.config['path_pretrained_segm'],
+                                restore_weights=self.config['restore_weights'],
+                                path_weights=paths_ckpt_sel['segm']))
+        self.models['segm'] = nn.DataParallel(self.models['segm'])
+        self.models['segm'] = self.models['segm'].to(maybe_gpu)
+
+        self.models['discr'] = (dict_models[config['model_discr']]
+                                (input_channels=self.config['output_channels'],
+                                 output_channels=1,
+                                 pretrained=self.config['pretrained'],
+                                 restore_weights=self.config['restore_weights'],
+                                 path_weights=paths_ckpt_sel['discr']))
+        self.models['discr'] = nn.DataParallel(self.models['discr'])
+        self.models['discr'] = self.models['discr'].to(maybe_gpu)
+
+        # Configure the training
+        self.optimizers = dict()
+        self.optimizers['segm'] = (dict_optimizers['adam'](
+            self.models['segm'].parameters(),
+            lr=self.config['lr_segm'],
+            weight_decay=self.config['wd_segm']))
+        self.optimizers['discr'] = (dict_optimizers['adam'](
+            self.models['discr'].parameters(),
+            lr=self.config['lr_discr'],
+            weight_decay=self.config['wd_discr']))
+
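+        # Step-wise LR schedule: at the epoch given by the key, the learning
+        # rates of both optimizers are multiplied by the value (see fit())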
+        self.lr_update_rule = {25: 0.1}
+
+        self.losses = dict()
+        self.losses['segm'] = dict_losses[self.config['loss_segm']](
+            num_classes=self.config['output_channels'],
+        )
+        self.losses['advers'] = dict_losses['bce_loss']()
+        self.losses['discr'] = dict_losses['bce_loss']()
+
+        self.losses['segm'] = self.losses['segm'].to(maybe_gpu)
+        self.losses['advers'] = self.losses['advers'].to(maybe_gpu)
+        self.losses['discr'] = self.losses['discr'].to(maybe_gpu)
+
+        self.tensorboard = SummaryWriter(self.path_logs_fold)
+
+    def run_one_epoch(self, epoch_idx, loaders):
+        COEFF_DISCR = 1
+        COEFF_SEGM = 1
+        COEFF_ADVERS = 0.001
+
+        fnames_acc = defaultdict(list)
+        metrics_acc = dict()
+        metrics_acc['samplew'] = defaultdict(list)
+        metrics_acc['batchw'] = defaultdict(list)
+        metrics_acc['datasetw'] = defaultdict(list)
+        metrics_acc['datasetw']['cm_oai'] = \
+            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
+        metrics_acc['datasetw']['cm_okoa'] = \
+            np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
+
+        prog_bar_params = {'postfix': {'epoch': epoch_idx}, }
+
+        if self.models['segm'].training and self.models['discr'].training:
+            # ------------------------ Training regime ------------------------
+            loader_oai = loaders['oai_imo']['train']
+            loader_maknee = loaders['maknee']['train']
+
+            steps_oai, steps_maknee = len(loader_oai), len(loader_maknee)
+            steps_total = steps_oai
+            prog_bar_params.update({'total': steps_total,
+                                    'desc': f'Train, epoch {epoch_idx}'})
+
+            loader_oai_iter = iter(loader_oai)
+            loader_maknee_iter = iter(loader_maknee)
+
+            loader_oai_iter_old = None
+            loader_maknee_iter_old = None
+
+            with tqdm(**prog_bar_params) as prog_bar:
+                for step_idx in range(steps_total):
+                    self.optimizers['segm'].zero_grad()
+                    self.optimizers['discr'].zero_grad()
+
+                    metrics_acc['batchw']['loss_total'].append(0)
+
+                    try:
+                        data_batch_oai = next(loader_oai_iter)
+                    except StopIteration:
+                        loader_oai_iter_old = loader_oai_iter
+                        loader_oai_iter = iter(loader_oai)
+                        data_batch_oai = next(loader_oai_iter)
+
+                    try:
+                        data_batch_maknee = next(loader_maknee_iter)
+                    except StopIteration:
+                        loader_maknee_iter_old = loader_maknee_iter
+                        loader_maknee_iter = iter(loader_maknee)
+                        data_batch_maknee = next(loader_maknee_iter)
+
+                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
+                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
+                    xs_oai = xs_oai.to(maybe_gpu)
+                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
+
+                    xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
+                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
+                    xs_maknee = xs_maknee.to(maybe_gpu)
+
+                    # -------------- Train discriminator network -------------
+                    # With source
+                    ys_pred_oai = self.models['segm'](xs_oai)
+                    ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)
+
+                    zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)
+
+                    # Use 0 as a label for the source domain
+                    loss_discr_0 = self.losses['discr'](
+                        input=zs_pred_oai,
+                        target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
+                    loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
+                    loss_discr_0.backward(retain_graph=True)
+                    metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_discr_0'][-1]
+
+                    # With target
+                    self.models['segm'] = self.models['segm'].eval()
+                    ys_pred_maknee = self.models['segm'](xs_maknee)
+                    self.models['segm'] = self.models['segm'].train()
+
+                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
+                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
+
+                    # Use 1 as a label for the target domain
+                    loss_discr_1 = self.losses['discr'](
+                        input=zs_pred_maknee,
+                        target=torch.ones_like(zs_pred_maknee, device=maybe_gpu))
+                    loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
+                    loss_discr_1.backward()
+                    metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_discr_1'][-1]
+
+                    self.models['segm'].zero_grad()
+                    self.optimizers['discr'].step()
+                    self.models['discr'].zero_grad()
+
+                    # ---------------- Train segmentation network ------------
+                    # With source
+                    if not self.config['with_mixup']:
+                        ys_pred_oai = self.models['segm'](xs_oai)
+                        loss_segm = self.losses['segm'](input_=ys_pred_oai,
+                                                        target=ys_true_arg_oai)
+                    else:
+                        xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
+                            x=xs_oai, y=ys_true_arg_oai,
+                            alpha=self.config['mixup_alpha'], device=maybe_gpu)
+                        ys_pred_oai = self.models['segm'](xs_mixup)
+                        loss_segm = mixup_criterion(criterion=self.losses['segm'],
+                                                    pred=ys_pred_oai,
+                                                    y_a=ys_mixup_a,
+                                                    y_b=ys_mixup_b,
+                                                    lam=lambda_mixup)
+
+                    loss_segm = loss_segm * COEFF_SEGM
+                    loss_segm.backward(retain_graph=True)
+                    metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_segm'][-1]
+
+                    # With target
+                    self.models['segm'] = self.models['segm'].eval()
+                    ys_pred_maknee = self.models['segm'](xs_maknee)
+                    self.models['segm'] = self.models['segm'].train()
+
+                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
+                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
+
+                    # Use 0 as a label for the source domain
+                    loss_advers = self.losses['advers'](
+                        input=zs_pred_maknee,
+                        target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
+                    loss_advers = loss_advers * COEFF_ADVERS
+                    loss_advers.backward()
+                    metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_advers'][-1]
+
+                    self.models['discr'].zero_grad()
+                    self.optimizers['segm'].step()
+
+                    if step_idx % 10 == 0:
+                        self.tensorboard.add_scalars(
+                            f'fold_{self.fold_idx}/losses_train',
+                            {'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
+                             'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
+                             'discr_sum_batchw':
+                                 (metrics_acc['batchw']['loss_discr_0'][-1] +
+                                  metrics_acc['batchw']['loss_discr_1'][-1]),
+                             'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
+                             'advers_batchw':
+                                 metrics_acc['batchw']['loss_advers'][-1],
+                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
+                             }, global_step=(epoch_idx * steps_total + step_idx))
+
+                    prog_bar.update(1)
+
+            del [loader_oai_iter_old, loader_maknee_iter_old]
+            gc.collect()
+        else:
+            # ----------------------- Validation regime -----------------------
+            loader_oai = loaders['oai_imo']['val']
+            loader_okoa = loaders['okoa']['val']
+            loader_maknee = loaders['maknee']['val']
+
+            steps_oai, steps_okoa, steps_maknee = \
+                len(loader_oai), len(loader_okoa), len(loader_maknee)
+            steps_total = steps_oai
+            prog_bar_params.update({'total': steps_total,
+                                    'desc': f'Validate, epoch {epoch_idx}'})
+
+            loader_oai_iter = iter(loader_oai)
+            loader_okoa_iter = iter(loader_okoa)
+            loader_maknee_iter = iter(loader_maknee)
+
+            loader_oai_iter_old = None
+            loader_okoa_iter_old = None
+            loader_maknee_iter_old = None
+
+            with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar:
+                for step_idx in range(steps_total):
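+                    # Validation mirrors the training pass, but without the
+                    # backward passes and optimizer steps; the OKOA and MAKNEE
+                    # iterators are restarted so every OAI batch has a counterpart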
+                    metrics_acc['batchw']['loss_total'].append(0)
+
+                    try:
+                        data_batch_oai = next(loader_oai_iter)
+                    except StopIteration:
+                        loader_oai_iter_old = loader_oai_iter
+                        loader_oai_iter = iter(loader_oai)
+                        data_batch_oai = next(loader_oai_iter)
+
+                    try:
+                        data_batch_okoa = next(loader_okoa_iter)
+                    except StopIteration:
+                        loader_okoa_iter_old = loader_okoa_iter
+                        loader_okoa_iter = iter(loader_okoa)
+                        data_batch_okoa = next(loader_okoa_iter)
+
+                    try:
+                        data_batch_maknee = next(loader_maknee_iter)
+                    except StopIteration:
+                        loader_maknee_iter_old = loader_maknee_iter
+                        loader_maknee_iter = iter(loader_maknee)
+                        data_batch_maknee = next(loader_maknee_iter)
+
+                    xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
+                    fnames_acc['oai'].extend(data_batch_oai['path_image'])
+                    xs_oai = xs_oai.to(maybe_gpu)
+                    ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
+
+                    xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
+                    fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
+                    xs_maknee = xs_maknee.to(maybe_gpu)
+
+                    # -------------- Validate discriminator network -------------
+                    # With source
+                    ys_pred_oai = self.models['segm'](xs_oai)
+                    ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)
+
+                    zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)
+
+                    # Use 0 as a label for the source domain
+                    loss_discr_0 = self.losses['discr'](
+                        input=zs_pred_oai,
+                        target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
+                    loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
+                    metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_discr_0'][-1]
+
+                    # With target
+                    ys_pred_maknee = self.models['segm'](xs_maknee)
+
+                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
+                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
+
+                    # Use 1 as a label for the target domain
+                    loss_discr_1 = self.losses['discr'](
+                        input=zs_pred_maknee,
+                        target=torch.ones_like(zs_pred_maknee, device=maybe_gpu))
+                    loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
+                    metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_discr_1'][-1]
+
+                    # ---------------- Validate segmentation network ------------
+                    # With source
+                    if not self.config['with_mixup']:
+                        ys_pred_oai = self.models['segm'](xs_oai)
+                        loss_segm = self.losses['segm'](input_=ys_pred_oai,
+                                                        target=ys_true_arg_oai)
+                    else:
+                        xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
+                            x=xs_oai, y=ys_true_arg_oai,
+                            alpha=self.config['mixup_alpha'], device=maybe_gpu)
+                        ys_pred_oai = self.models['segm'](xs_mixup)
+                        loss_segm = mixup_criterion(criterion=self.losses['segm'],
+                                                    pred=ys_pred_oai,
+                                                    y_a=ys_mixup_a,
+                                                    y_b=ys_mixup_b,
+                                                    lam=lambda_mixup)
+
+                    loss_segm = loss_segm * COEFF_SEGM
+                    metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_segm'][-1]
+
+                    # With target
+                    ys_pred_maknee = self.models['segm'](xs_maknee)
+
+                    ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
+                    zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
+
+                    # Use 0 as a label for the source domain
+                    loss_advers = self.losses['advers'](
+                        input=zs_pred_maknee,
+                        target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
+                    loss_advers = loss_advers * COEFF_ADVERS
+                    metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
+                    metrics_acc['batchw']['loss_total'][-1] += \
+                        metrics_acc['batchw']['loss_advers'][-1]
+
+                    if step_idx % 10 == 0:
+                        self.tensorboard.add_scalars(
+                            f'fold_{self.fold_idx}/losses_val',
+                            {'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
+                             'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
+                             'discr_sum_batchw':
+                                 (metrics_acc['batchw']['loss_discr_0'][-1] +
+                                  metrics_acc['batchw']['loss_discr_1'][-1]),
+                             'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
+                             'advers_batchw':
+                                 metrics_acc['batchw']['loss_advers'][-1],
+                             'total_batchw': metrics_acc['batchw']['loss_total'][-1],
+                             }, global_step=(epoch_idx * steps_total + step_idx))
+
+                    # ------------------ Calculate metrics -------------------
+                    ys_pred_arg_np_oai = \
+                        torch.argmax(ys_pred_softmax_oai, 1).to('cpu').numpy()
+                    ys_true_arg_np_oai = ys_true_arg_oai.to('cpu').numpy()
+
+                    metrics_acc['datasetw']['cm_oai'] += confusion_matrix(
+                        ys_pred_arg_np_oai, ys_true_arg_np_oai,
+                        self.config['output_channels'])
+
+                    # Don't consider repeating entries for the metrics calculation
+                    if step_idx < steps_okoa:
+                        xs_okoa, ys_true_okoa = data_batch_okoa['xs'], data_batch_okoa['ys']
+                        fnames_acc['okoa'].extend(data_batch_okoa['path_image'])
+                        xs_okoa = xs_okoa.to(maybe_gpu)
+
+                        ys_pred_okoa = self.models['segm'](xs_okoa)
+
+                        ys_true_arg_okoa = \
+                            torch.argmax(ys_true_okoa.long().to(maybe_gpu), dim=1)
+                        ys_pred_softmax_okoa = torch_fn.softmax(ys_pred_okoa, dim=1)
+
+                        ys_pred_arg_np_okoa = \
+                            torch.argmax(ys_pred_softmax_okoa, 1).to('cpu').numpy()
+                        ys_true_arg_np_okoa = ys_true_arg_okoa.to('cpu').numpy()
+
+                        metrics_acc['datasetw']['cm_okoa'] += confusion_matrix(
+                            ys_pred_arg_np_okoa, ys_true_arg_np_okoa,
+                            self.config['output_channels'])
+
+                    prog_bar.update(1)
+
+            del [loader_oai_iter_old, loader_okoa_iter_old, loader_maknee_iter_old]
+            gc.collect()
+
+        for k, v in metrics_acc['samplew'].items():
+            metrics_acc['samplew'][k] = np.asarray(v)
+        metrics_acc['datasetw']['dice_score_oai'] = np.asarray(
+            dice_score_from_cm(metrics_acc['datasetw']['cm_oai']))
+        metrics_acc['datasetw']['dice_score_okoa'] = np.asarray(
+            dice_score_from_cm(metrics_acc['datasetw']['cm_okoa']))
+        return metrics_acc, fnames_acc
+
+    def fit(self, loaders):
+        epoch_idx_best = -1
+        loss_best = float('inf')
+        metrics_train_best = dict()
+        fnames_train_best = []
+        metrics_val_best = dict()
+        fnames_val_best = []
+
+        for epoch_idx in range(self.config['epoch_num']):
+            self.models = {n: m.train() for n, m in self.models.items()}
+            metrics_train, fnames_train = \
+                self.run_one_epoch(epoch_idx, loaders)
+
+            # Process the accumulated metrics
+            for k, v in metrics_train['batchw'].items():
+                if k.startswith('loss'):
+                    metrics_train['datasetw'][k] = np.mean(np.asarray(v))
+                else:
+                    logger.warning(f'Non-processed batch-wise entry: {k}')
+
+            self.models = {n: m.eval() for n, m in self.models.items()}
+            metrics_val, fnames_val = \
+                self.run_one_epoch(epoch_idx, loaders)
+
+            # Process the accumulated metrics
+            for k, v in metrics_val['batchw'].items():
+                if k.startswith('loss'):
+                    metrics_val['datasetw'][k] = np.mean(np.asarray(v))
+                else:
+                    logger.warning(f'Non-processed batch-wise entry: {k}')
+
+            # Learning rate update
+            for s, m in self.lr_update_rule.items():
+                if epoch_idx == s:
+                    for name, optim in self.optimizers.items():
+                        for param_group in optim.param_groups:
+                            param_group['lr'] *= m
+
+            # Add console logging
+            logger.info(f'Epoch: {epoch_idx}')
+            for subset, metrics in (('train', metrics_train),
+                                    ('val', metrics_val)):
+                logger.info(f'{subset} metrics:')
+                for k, v in metrics['datasetw'].items():
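+                    # Dataset-wise entries are confusion matrices, per-class
+                    # Dice score arrays, and epoch-mean losses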
+                    logger.info(f'{k}:\n{v}')
+
+            # Add TensorBoard logging
+            for subset, metrics in (('train', metrics_train),
+                                    ('val', metrics_val)):
+                # Log only dataset-reduced metrics
+                for k, v in metrics['datasetw'].items():
+                    if isinstance(v, np.ndarray):
+                        self.tensorboard.add_scalars(
+                            f'fold_{self.fold_idx}/{k}_{subset}',
+                            {f'class{i}': e for i, e in enumerate(v.ravel().tolist())},
+                            global_step=epoch_idx)
+                    elif isinstance(v, (str, int, float)):
+                        self.tensorboard.add_scalar(
+                            f'fold_{self.fold_idx}/{k}_{subset}',
+                            float(v),
+                            global_step=epoch_idx)
+                    else:
+                        logger.warning(f'{k} has unsupported type: {type(v)}')
+            for name, optim in self.optimizers.items():
+                for param_group in optim.param_groups:
+                    self.tensorboard.add_scalar(
+                        f'fold_{self.fold_idx}/learning_rate/{name}',
+                        param_group['lr'],
+                        global_step=epoch_idx)
+
+            # Save the model
+            loss_curr = metrics_val['datasetw']['loss_total']
+            if loss_curr < loss_best:
+                loss_best = loss_curr
+                epoch_idx_best = epoch_idx
+                metrics_train_best = metrics_train
+                metrics_val_best = metrics_val
+                fnames_train_best = fnames_train
+                fnames_val_best = fnames_val
+
+                self.handlers_ckpt['segm'].save_new_ckpt(
+                    model=self.models['segm'],
+                    model_name=self.config['model_segm'],
+                    fold_idx=self.fold_idx,
+                    epoch_idx=epoch_idx)
+                self.handlers_ckpt['discr'].save_new_ckpt(
+                    model=self.models['discr'],
+                    model_name=self.config['model_discr'],
+                    fold_idx=self.fold_idx,
+                    epoch_idx=epoch_idx)
+
+        msg = (f'Finished fold {self.fold_idx} '
+               f'with the best loss {loss_best:.5f} '
+               f'on epoch {epoch_idx_best}, '
+               f'weights: ({self.paths_weights_fold})')
+        logger.info(msg)
+        return (metrics_train_best, fnames_train_best,
+                metrics_val_best, fnames_val_best)
+
+
+@click.command()
+@click.option('--path_data_root', default='../../data')
+@click.option('--path_experiment_root', default='../../results/temporary')
+@click.option('--model_segm', default='unet_lext')
+@click.option('--center_depth', default=1, type=int)
+@click.option('--model_discr', default='discriminator_a')
+@click.option('--pretrained', is_flag=True)
+@click.option('--path_pretrained_segm', type=str, help='Path to .pth file')
+@click.option('--restore_weights', is_flag=True)
+@click.option('--input_channels', default=1, type=int)
+@click.option('--output_channels', default=1, type=int)
+@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
+@click.option('--sample_mode', default='x_y', type=str)
+@click.option('--loss_segm', default='multi_ce_loss')
+@click.option('--lr_segm', default=0.0001, type=float)
+@click.option('--lr_discr', default=0.0001, type=float)
+@click.option('--wd_segm', default=5e-5, type=float)
+@click.option('--wd_discr', default=5e-5, type=float)
+@click.option('--optimizer_segm', default='adam')
+@click.option('--optimizer_discr', default='adam')
+@click.option('--batch_size', default=64, type=int)
+@click.option('--epoch_size', default=1.0, type=float)
+@click.option('--epoch_num', default=2, type=int)
+@click.option('--fold_num', default=5, type=int)
+@click.option('--fold_idx', default=-1, type=int)
+@click.option('--fold_idx_ignore', multiple=True, type=int)
+@click.option('--num_workers', default=1, type=int)
+@click.option('--seed_trainval_test', default=0, type=int)
+@click.option('--with_mixup', is_flag=True)
+@click.option('--mixup_alpha', default=1, type=float)
+def main(**config):
+    config['path_data_root'] = os.path.abspath(config['path_data_root'])
+    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
+
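+    # All experiment artifacts are written under path_experiment_root:
+    # per-fold model weights and the training logs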
+    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
+    config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train')
+    os.makedirs(config['path_weights'], exist_ok=True)
+    os.makedirs(config['path_logs'], exist_ok=True)
+
+    logging_fh = logging.FileHandler(
+        os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx'])))
+    logging_fh.setLevel(logging.DEBUG)
+    logger.addHandler(logging_fh)
+
+    # Collect the available and specified sources
+    sources = sources_from_path(path_data_root=config['path_data_root'],
+                                selection=('oai_imo', 'okoa', 'maknee'),
+                                with_folds=True,
+                                fold_num=config['fold_num'],
+                                seed_trainval_test=config['seed_trainval_test'])
+
+    # Build a list of folds to run on
+    if config['fold_idx'] == -1:
+        fold_idcs = list(range(config['fold_num']))
+    else:
+        fold_idcs = [config['fold_idx'], ]
+    for g in config['fold_idx_ignore']:
+        fold_idcs = [i for i in fold_idcs if i != g]
+
+    # Train each fold separately
+    fold_scores = dict()
+
+    # Use straightforward fold allocation strategy
+    folds = list(zip(sources['oai_imo']['trainval_folds'],
+                     sources['okoa']['trainval_folds'],
+                     sources['maknee']['trainval_folds']))
+
+    for fold_idx, idcs_subsets in enumerate(folds):
+        if fold_idx not in fold_idcs:
+            continue
+        logger.info(f'Training fold {fold_idx}')
+
+        (sources['oai_imo']['train_idcs'],
+         sources['oai_imo']['val_idcs']) = idcs_subsets[0]
+        (sources['okoa']['train_idcs'],
+         sources['okoa']['val_idcs']) = idcs_subsets[1]
+        (sources['maknee']['train_idcs'],
+         sources['maknee']['val_idcs']) = idcs_subsets[2]
+
+        sources['oai_imo']['train_df'] = \
+            sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['train_idcs']]
+        sources['oai_imo']['val_df'] = \
+            sources['oai_imo']['trainval_df'].iloc[sources['oai_imo']['val_idcs']]
+        sources['okoa']['train_df'] = \
+            sources['okoa']['trainval_df'].iloc[sources['okoa']['train_idcs']]
+        sources['okoa']['val_df'] = \
+            sources['okoa']['trainval_df'].iloc[sources['okoa']['val_idcs']]
+        sources['maknee']['train_df'] = \
+            sources['maknee']['trainval_df'].iloc[sources['maknee']['train_idcs']]
+        sources['maknee']['val_df'] = \
+            sources['maknee']['trainval_df'].iloc[sources['maknee']['val_idcs']]
+
+        for n, s in sources.items():
+            logger.info('Made {} train-val split, number of samples: {}, {}'
+                        .format(n, len(s['train_df']), len(s['val_df'])))
+
+        datasets = defaultdict(dict)
+
+        datasets['oai_imo']['train'] = DatasetOAIiMoSagittal2d(
+            df_meta=sources['oai_imo']['train_df'],
+            mask_mode=config['mask_mode'],
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                HorizontalFlip(prob=.5),
+                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
+                OneOf([
+                    DualCompose([
+                        Scale(ratio_range=(0.7, 0.8), prob=1.),
+                        Scale(ratio_range=(1.5, 1.6), prob=1.),
+                    ]),
+                    NoTransform()
+                ]),
+                Crop(output_size=(300, 300)),
+                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor(),
+            ])
+        datasets['okoa']['train'] = DatasetOKOASagittal2d(
+            df_meta=sources['okoa']['train_df'],
+            mask_mode='background_femoral_unitibial',
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                HorizontalFlip(prob=.5),
+                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
+                OneOf([
+                    DualCompose([
+                        Scale(ratio_range=(0.7, 0.8), prob=1.),
+                        Scale(ratio_range=(1.5, 1.6), prob=1.),
+                    ]),
+                    NoTransform()
+                ]),
+                Crop(output_size=(300, 300)),
+                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor(),
+            ])
+        datasets['maknee']['train'] = DatasetMAKNEESagittal2d(
+            df_meta=sources['maknee']['train_df'],
+            mask_mode='',
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                HorizontalFlip(prob=.5),
+                GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
+                OneOf([
+                    DualCompose([
+                        Scale(ratio_range=(0.7, 0.8), prob=1.),
+                        Scale(ratio_range=(1.5, 1.6), prob=1.),
+                    ]),
+                    NoTransform()
+                ]),
+                Crop(output_size=(300, 300)),
+                BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor(),
+            ])
+        datasets['oai_imo']['val'] = DatasetOAIiMoSagittal2d(
+            df_meta=sources['oai_imo']['val_df'],
+            mask_mode=config['mask_mode'],
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor()
+            ])
+        datasets['okoa']['val'] = DatasetOKOASagittal2d(
+            df_meta=sources['okoa']['val_df'],
+            mask_mode='background_femoral_unitibial',
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor()
+            ])
+        datasets['maknee']['val'] = DatasetMAKNEESagittal2d(
+            df_meta=sources['maknee']['val_df'],
+            mask_mode='',
+            sample_mode=config['sample_mode'],
+            transforms=[
+                PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+                CenterCrop(height=300, width=300),
+                Normalize(mean=0.252699, std=0.251142),
+                ToTensor()
+            ])
+
+        loaders = defaultdict(dict)
+
+        loaders['oai_imo']['train'] = DataLoader(
+            datasets['oai_imo']['train'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=True,
+            num_workers=config['num_workers'],
+            drop_last=True)
+        loaders['oai_imo']['val'] = DataLoader(
+            datasets['oai_imo']['val'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=False,
+            num_workers=config['num_workers'],
+            drop_last=True)
+        loaders['okoa']['train'] = DataLoader(
+            datasets['okoa']['train'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=True,
+            num_workers=config['num_workers'],
+            drop_last=True)
+        loaders['okoa']['val'] = DataLoader(
+            datasets['okoa']['val'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=False,
+            num_workers=config['num_workers'],
+            drop_last=True)
+        loaders['maknee']['train'] = DataLoader(
+            datasets['maknee']['train'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=True,
+            num_workers=config['num_workers'],
+            drop_last=True)
+        loaders['maknee']['val'] = DataLoader(
+            datasets['maknee']['val'],
+            batch_size=int(config['batch_size'] / 2),
+            shuffle=False,
+            num_workers=config['num_workers'],
+            drop_last=True)
+
+        trainer = ModelTrainer(config=config, fold_idx=fold_idx)
+
+        (metrics_train, fnames_train,
+         metrics_val, fnames_val) = trainer.fit(loaders=loaders)
+
+        fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'],
+                                 metrics_val['datasetw']['dice_score_okoa'])
+        trainer.tensorboard.close()
+    logger.info(f'Fold scores:\n{repr(fold_scores)}')
+
+
+if __name__ == '__main__':
+    main()
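+
+# Example invocation (illustrative values only; all flags are defined by the
+# click options above):
+#   python train_uda1.py --path_data_root ../../data \
+#       --path_experiment_root ../../results/uda1 \
+#       --batch_size 64 --epoch_num 30 --fold_idx 0 --with_mixup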