import os
import logging
from collections import defaultdict
import gc
import click
import resource
import numpy as np
import cv2
import torch
import torch.nn.functional as torch_fn
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from rocaseg.datasets import (DatasetOAIiMoSagittal2d,
DatasetOKOASagittal2d,
DatasetMAKNEESagittal2d,
sources_from_path)
from rocaseg.models import dict_models
from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm,
dict_optimizers, CheckpointHandler)
from rocaseg.preproc import *
from rocaseg.repro import set_ultimate_seed
from rocaseg.components.mixup import mixup_criterion, mixup_data
# Raise the soft open-file limit: many DataLoader workers can exhaust the default.
_soft_limit, _hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, _hard_limit))

# Keep OpenCV single-threaded and off OpenCL so it does not fight the
# DataLoader worker processes for resources.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

logging.basicConfig()
logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)

set_ultimate_seed()

# Device selection happens once; models and tensors are moved to it explicitly.
maybe_gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
class ModelTrainer:
    """Adversarial domain-adaptation trainer.

    Holds a segmentation network ('segm') and a domain discriminator
    ('discr'), their optimizers, losses, checkpoint handlers, and a
    per-fold TensorBoard writer.
    """

    def __init__(self, config, fold_idx=None):
        self.config = config
        self.fold_idx = fold_idx

        # Per-fold weight directories, checkpoint handlers, and the most
        # recent checkpoint (if any) for each sub-model.
        self.paths_weights_fold = {}
        self.handlers_ckpt = {}
        paths_ckpt_sel = {}
        for part in ('segm', 'discr'):
            path = os.path.join(config['path_weights'], part,
                                f'fold_{self.fold_idx}')
            os.makedirs(path, exist_ok=True)
            self.paths_weights_fold[part] = path
            self.handlers_ckpt[part] = CheckpointHandler(path)
            paths_ckpt_sel[part] = self.handlers_ckpt[part].get_last_ckpt()

        self.path_logs_fold = os.path.join(config['path_logs'],
                                           f'fold_{self.fold_idx}')
        os.makedirs(self.path_logs_fold, exist_ok=True)

        # Initialize and configure the models
        self.models = {}

        segm_cls = dict_models[config['model_segm']]
        self.models['segm'] = segm_cls(
            input_channels=self.config['input_channels'],
            output_channels=self.config['output_channels'],
            center_depth=self.config['center_depth'],
            pretrained=self.config['pretrained'],
            path_pretrained=self.config['path_pretrained_segm'],
            restore_weights=self.config['restore_weights'],
            path_weights=paths_ckpt_sel['segm'])
        self.models['segm'] = nn.DataParallel(self.models['segm']).to(maybe_gpu)

        # The discriminator consumes the segmentation softmax maps, hence its
        # input width equals the number of segmentation output channels.
        discr_cls = dict_models[config['model_discr']]
        self.models['discr'] = discr_cls(
            input_channels=self.config['output_channels'],
            output_channels=1,
            pretrained=self.config['pretrained'],
            restore_weights=self.config['restore_weights'],
            path_weights=paths_ckpt_sel['discr'])
        self.models['discr'] = nn.DataParallel(self.models['discr']).to(maybe_gpu)

        # Configure the training: one Adam optimizer per sub-model.
        self.optimizers = {
            'segm': dict_optimizers['adam'](
                self.models['segm'].parameters(),
                lr=self.config['lr_segm'],
                weight_decay=self.config['wd_segm']),
            'discr': dict_optimizers['adam'](
                self.models['discr'].parameters(),
                lr=self.config['lr_discr'],
                weight_decay=self.config['wd_discr']),
        }
        # epoch index -> multiplicative learning-rate decay factor
        self.lr_update_rule = {25: 0.1}

        # Segmentation loss from config; BCE for both GAN-style terms.
        self.losses = {
            'segm': dict_losses[self.config['loss_segm']](
                num_classes=self.config['output_channels']),
            'advers': dict_losses['bce_loss'](),
            'discr': dict_losses['bce_loss'](),
        }
        for name in ('segm', 'advers', 'discr'):
            self.losses[name] = self.losses[name].to(maybe_gpu)

        self.tensorboard = SummaryWriter(self.path_logs_fold)
def run_one_epoch(self, epoch_idx, loaders):
COEFF_DISCR = 1
COEFF_SEGM = 1
COEFF_ADVERS = 0.001
fnames_acc = defaultdict(list)
metrics_acc = dict()
metrics_acc['samplew'] = defaultdict(list)
metrics_acc['batchw'] = defaultdict(list)
metrics_acc['datasetw'] = defaultdict(list)
metrics_acc['datasetw']['cm_oai'] = \
np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
metrics_acc['datasetw']['cm_okoa'] = \
np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
prog_bar_params = {'postfix': {'epoch': epoch_idx}, }
if self.models['segm'].training and self.models['discr'].training:
# ------------------------ Training regime ------------------------
loader_oai = loaders['oai_imo']['train']
loader_maknee = loaders['maknee']['train']
steps_oai, steps_maknee = len(loader_oai), len(loader_maknee)
steps_total = steps_oai
prog_bar_params.update({'total': steps_total,
'desc': f'Train, epoch {epoch_idx}'})
loader_oai_iter = iter(loader_oai)
loader_maknee_iter = iter(loader_maknee)
loader_oai_iter_old = None
loader_maknee_iter_old = None
with tqdm(**prog_bar_params) as prog_bar:
for step_idx in range(steps_total):
self.optimizers['segm'].zero_grad()
self.optimizers['discr'].zero_grad()
metrics_acc['batchw']['loss_total'].append(0)
try:
data_batch_oai = next(loader_oai_iter)
except StopIteration:
loader_oai_iter_old = loader_oai_iter
loader_oai_iter = iter(loader_oai)
data_batch_oai = next(loader_oai_iter)
try:
data_batch_maknee = next(loader_maknee_iter)
except StopIteration:
loader_maknee_iter_old = loader_maknee_iter
loader_maknee_iter = iter(loader_maknee)
data_batch_maknee = next(loader_maknee_iter)
xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
fnames_acc['oai'].extend(data_batch_oai['path_image'])
xs_oai = xs_oai.to(maybe_gpu)
ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
xs_maknee = xs_maknee.to(maybe_gpu)
# -------------- Train discriminator network -------------
# With source
ys_pred_oai = self.models['segm'](xs_oai)
ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)
zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)
# Use 0 as a label for the source domain
loss_discr_0 = self.losses['discr'](
input=zs_pred_oai,
target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
loss_discr_0.backward(retain_graph=True)
metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_discr_0'][-1]
# With target
self.models['segm'] = self.models['segm'].eval()
ys_pred_maknee = self.models['segm'](xs_maknee)
self.models['segm'] = self.models['segm'].train()
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
# Use 1 as a label for the target domain
loss_discr_1 = self.losses['discr'](
input=zs_pred_maknee,
target=torch.ones_like(zs_pred_maknee, device=maybe_gpu))
loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
loss_discr_1.backward()
metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_discr_1'][-1]
self.models['segm'].zero_grad()
self.optimizers['discr'].step()
self.models['discr'].zero_grad()
# ---------------- Train segmentation network ------------
# With source
if not self.config['with_mixup']:
ys_pred_oai = self.models['segm'](xs_oai)
loss_segm = self.losses['segm'](input_=ys_pred_oai,
target=ys_true_arg_oai)
else:
xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
x=xs_oai, y=ys_true_arg_oai,
alpha=self.config['mixup_alpha'], device=maybe_gpu)
ys_pred_oai = self.models['segm'](xs_mixup)
loss_segm = mixup_criterion(criterion=self.losses['segm'],
pred=ys_pred_oai,
y_a=ys_mixup_a,
y_b=ys_mixup_b,
lam=lambda_mixup)
loss_segm.backward(retain_graph=True)
loss_segm = loss_segm * COEFF_SEGM
metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_segm'][-1]
# With target
self.models['segm'] = self.models['segm'].eval()
ys_pred_maknee = self.models['segm'](xs_maknee)
self.models['segm'] = self.models['segm'].train()
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
# Use 0 as a label for the source domain
loss_advers = self.losses['advers'](
input=zs_pred_maknee,
target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
loss_advers = loss_advers * COEFF_ADVERS
loss_advers.backward()
metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_advers'][-1]
self.models['discr'].zero_grad()
self.optimizers['segm'].step()
if step_idx % 10 == 0:
self.tensorboard.add_scalars(
f'fold_{self.fold_idx}/losses_train',
{'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
'discr_sum_batchw':
(metrics_acc['batchw']['loss_discr_0'][-1] +
metrics_acc['batchw']['loss_discr_1'][-1]),
'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
'advers_batchw':
metrics_acc['batchw']['loss_advers'][-1],
'total_batchw': metrics_acc['batchw']['loss_total'][-1],
}, global_step=(epoch_idx * steps_total + step_idx))
prog_bar.update(1)
del [loader_oai_iter_old, loader_maknee_iter_old]
gc.collect()
else:
# ----------------------- Validation regime -----------------------
loader_oai = loaders['oai_imo']['val']
loader_okoa = loaders['okoa']['val']
loader_maknee = loaders['maknee']['val']
steps_oai, steps_okoa, steps_maknee = len(loader_oai), len(loader_okoa), len(loader_maknee)
steps_total = steps_oai
prog_bar_params.update({'total': steps_total,
'desc': f'Validate, epoch {epoch_idx}'})
loader_oai_iter = iter(loader_oai)
loader_okoa_iter = iter(loader_okoa)
loader_maknee_iter = iter(loader_maknee)
loader_oai_iter_old = None
loader_okoa_iter_old = None
loader_maknee_iter_old = None
with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar:
for step_idx in range(steps_total):
metrics_acc['batchw']['loss_total'].append(0)
try:
data_batch_oai = next(loader_oai_iter)
except StopIteration:
loader_oai_iter_old = loader_oai_iter
loader_oai_iter = iter(loader_oai)
data_batch_oai = next(loader_oai_iter)
try:
data_batch_okoa = next(loader_okoa_iter)
except StopIteration:
loader_okoa_iter_old = loader_okoa_iter
loader_okoa_iter = iter(loader_okoa)
data_batch_okoa = next(loader_okoa_iter)
try:
data_batch_maknee = next(loader_maknee_iter)
except StopIteration:
loader_maknee_iter_old = loader_maknee_iter
loader_maknee_iter = iter(loader_maknee)
data_batch_maknee = next(loader_maknee_iter)
xs_oai, ys_true_oai = data_batch_oai['xs'], data_batch_oai['ys']
fnames_acc['oai'].extend(data_batch_oai['path_image'])
xs_oai = xs_oai.to(maybe_gpu)
ys_true_arg_oai = torch.argmax(ys_true_oai.long().to(maybe_gpu), dim=1)
xs_maknee, _ = data_batch_maknee['xs'], data_batch_maknee['ys']
fnames_acc['maknee'].extend(data_batch_maknee['path_image'])
xs_maknee = xs_maknee.to(maybe_gpu)
# -------------- Validate discriminator network -------------
# With source
ys_pred_oai = self.models['segm'](xs_oai)
ys_pred_softmax_oai = torch_fn.softmax(ys_pred_oai, dim=1)
zs_pred_oai = self.models['discr'](ys_pred_softmax_oai)
# Use 0 as a label for the source domain
loss_discr_0 = self.losses['discr'](
input=zs_pred_oai,
target=torch.zeros_like(zs_pred_oai, device=maybe_gpu))
loss_discr_0 = loss_discr_0 / 2 * COEFF_DISCR
metrics_acc['batchw']['loss_discr_0'].append(loss_discr_0.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_discr_0'][-1]
# With target
ys_pred_maknee = self.models['segm'](xs_maknee)
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
# Use 1 as a label for the target domain
loss_discr_1 = self.losses['discr'](
input=zs_pred_maknee,
target=torch.ones_like(zs_pred_oai, device=maybe_gpu))
loss_discr_1 = loss_discr_1 / 2 * COEFF_DISCR
metrics_acc['batchw']['loss_discr_1'].append(loss_discr_1.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_discr_1'][-1]
# ---------------- Validate segmentation network ------------
# With source
if not self.config['with_mixup']:
ys_pred_oai = self.models['segm'](xs_oai)
loss_segm = self.losses['segm'](input_=ys_pred_oai,
target=ys_true_arg_oai)
else:
xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
x=xs_oai, y=ys_true_arg_oai,
alpha=self.config['mixup_alpha'], device=maybe_gpu)
ys_pred_oai = self.models['segm'](xs_mixup)
loss_segm = mixup_criterion(criterion=self.losses['segm'],
pred=ys_pred_oai,
y_a=ys_mixup_a,
y_b=ys_mixup_b,
lam=lambda_mixup)
loss_segm = loss_segm * COEFF_SEGM
metrics_acc['batchw']['loss_segm'].append(loss_segm.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_segm'][-1]
# With target
ys_pred_maknee = self.models['segm'](xs_maknee)
ys_pred_softmax_maknee = torch_fn.softmax(ys_pred_maknee, dim=1)
zs_pred_maknee = self.models['discr'](ys_pred_softmax_maknee)
# Use 0 as a label for the source domain
loss_advers = self.losses['advers'](
input=zs_pred_maknee,
target=torch.zeros_like(zs_pred_maknee, device=maybe_gpu))
loss_advers = loss_advers * COEFF_ADVERS
metrics_acc['batchw']['loss_advers'].append(loss_advers.item())
metrics_acc['batchw']['loss_total'][-1] += \
metrics_acc['batchw']['loss_advers'][-1]
if step_idx % 10 == 0:
self.tensorboard.add_scalars(
f'fold_{self.fold_idx}/losses_val',
{'discr_0_batchw': metrics_acc['batchw']['loss_discr_0'][-1],
'discr_1_batchw': metrics_acc['batchw']['loss_discr_1'][-1],
'discr_sum_batchw':
(metrics_acc['batchw']['loss_discr_0'][-1] +
metrics_acc['batchw']['loss_discr_1'][-1]),
'segm_batchw': metrics_acc['batchw']['loss_segm'][-1],
'advers_batchw':
metrics_acc['batchw']['loss_advers'][-1],
'total_batchw': metrics_acc['batchw']['loss_total'][-1],
}, global_step=(epoch_idx * steps_total + step_idx))
# ------------------ Calculate metrics -------------------
ys_pred_arg_np_oai = torch.argmax(ys_pred_softmax_oai, 1).to('cpu').numpy()
ys_true_arg_np_oai = ys_true_arg_oai.to('cpu').numpy()
metrics_acc['datasetw']['cm_oai'] += confusion_matrix(
ys_pred_arg_np_oai, ys_true_arg_np_oai,
self.config['output_channels'])
# Don't consider repeating entries for the metrics calculation
if step_idx < steps_okoa:
xs_okoa, ys_true_okoa = data_batch_okoa['xs'], data_batch_okoa['ys']
fnames_acc['okoa'].extend(data_batch_okoa['path_image'])
xs_okoa = xs_okoa.to(maybe_gpu)
ys_pred_okoa = self.models['segm'](xs_okoa)
ys_true_arg_okoa = torch.argmax(ys_true_okoa.long().to(maybe_gpu), dim=1)
ys_pred_softmax_okoa = torch_fn.softmax(ys_pred_okoa, dim=1)
ys_pred_arg_np_okoa = torch.argmax(ys_pred_softmax_okoa, 1).to('cpu').numpy()
ys_true_arg_np_okoa = ys_true_arg_okoa.to('cpu').numpy()
metrics_acc['datasetw']['cm_okoa'] += confusion_matrix(
ys_pred_arg_np_okoa, ys_true_arg_np_okoa,
self.config['output_channels'])
prog_bar.update(1)
del [loader_oai_iter_old, loader_okoa_iter_old, loader_maknee_iter_old]
gc.collect()
for k, v in metrics_acc['samplew'].items():
metrics_acc['samplew'][k] = np.asarray(v)
metrics_acc['datasetw']['dice_score_oai'] = np.asarray(
dice_score_from_cm(metrics_acc['datasetw']['cm_oai']))
metrics_acc['datasetw']['dice_score_okoa'] = np.asarray(
dice_score_from_cm(metrics_acc['datasetw']['cm_okoa']))
return metrics_acc, fnames_acc
def fit(self, loaders):
epoch_idx_best = -1
loss_best = float('inf')
metrics_train_best = dict()
fnames_train_best = []
metrics_val_best = dict()
fnames_val_best = []
for epoch_idx in range(self.config['epoch_num']):
self.models = {n: m.train() for n, m in self.models.items()}
metrics_train, fnames_train = \
self.run_one_epoch(epoch_idx, loaders)
# Process the accumulated metrics
for k, v in metrics_train['batchw'].items():
if k.startswith('loss'):
metrics_train['datasetw'][k] = np.mean(np.asarray(v))
else:
logger.warning(f'Non-processed batch-wise entry: {k}')
self.models = {n: m.eval() for n, m in self.models.items()}
metrics_val, fnames_val = \
self.run_one_epoch(epoch_idx, loaders)
# Process the accumulated metrics
for k, v in metrics_val['batchw'].items():
if k.startswith('loss'):
metrics_val['datasetw'][k] = np.mean(np.asarray(v))
else:
logger.warning(f'Non-processed batch-wise entry: {k}')
# Learning rate update
for s, m in self.lr_update_rule.items():
if epoch_idx == s:
for name, optim in self.optimizers.items():
for param_group in optim.param_groups:
param_group['lr'] *= m
# Add console logging
logger.info(f'Epoch: {epoch_idx}')
for subset, metrics in (('train', metrics_train),
('val', metrics_val)):
logger.info(f'{subset} metrics:')
for k, v in metrics['datasetw'].items():
logger.info(f'{k}: \n{v}')
# Add TensorBoard logging
for subset, metrics in (('train', metrics_train),
('val', metrics_val)):
# Log only dataset-reduced metrics
for k, v in metrics['datasetw'].items():
if isinstance(v, np.ndarray):
self.tensorboard.add_scalars(
f'fold_{self.fold_idx}/{k}_{subset}',
{f'class{i}': e for i, e in enumerate(v.ravel().tolist())},
global_step=epoch_idx)
elif isinstance(v, (str, int, float)):
self.tensorboard.add_scalar(
f'fold_{self.fold_idx}/{k}_{subset}',
float(v),
global_step=epoch_idx)
else:
logger.warning(f'{k} is of unsupported dtype {v}')
for name, optim in self.optimizers.items():
for param_group in optim.param_groups:
self.tensorboard.add_scalar(
f'fold_{self.fold_idx}/learning_rate/{name}',
param_group['lr'],
global_step=epoch_idx)
# Save the model
loss_curr = metrics_val['datasetw']['loss_total']
if loss_curr < loss_best:
loss_best = loss_curr
epoch_idx_best = epoch_idx
metrics_train_best = metrics_train
metrics_val_best = metrics_val
fnames_train_best = fnames_train
fnames_val_best = fnames_val
self.handlers_ckpt['segm'].save_new_ckpt(
model=self.models['segm'],
model_name=self.config['model_segm'],
fold_idx=self.fold_idx,
epoch_idx=epoch_idx)
self.handlers_ckpt['discr'].save_new_ckpt(
model=self.models['discr'],
model_name=self.config['model_discr'],
fold_idx=self.fold_idx,
epoch_idx=epoch_idx)
msg = (f'Finished fold {self.fold_idx} '
f'with the best loss {loss_best:.5f} '
f'on epoch {epoch_idx_best}, '
f'weights: ({self.paths_weights_fold})')
logger.info(msg)
return (metrics_train_best, fnames_train_best,
metrics_val_best, fnames_val_best)
def _transforms_train():
    """Augmentation pipeline shared by all three training datasets."""
    return [
        PercentileClippingAndToFloat(cut_min=10, cut_max=99),
        CenterCrop(height=300, width=300),
        HorizontalFlip(prob=.5),
        GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
        OneOf([
            DualCompose([
                Scale(ratio_range=(0.7, 0.8), prob=1.),
                Scale(ratio_range=(1.5, 1.6), prob=1.),
            ]),
            NoTransform()
        ]),
        Crop(output_size=(300, 300)),
        BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
        Normalize(mean=0.252699, std=0.251142),
        ToTensor(),
    ]


def _transforms_val():
    """Deterministic preprocessing pipeline for the validation datasets."""
    return [
        PercentileClippingAndToFloat(cut_min=10, cut_max=99),
        CenterCrop(height=300, width=300),
        Normalize(mean=0.252699, std=0.251142),
        ToTensor()
    ]


@click.command()
@click.option('--path_data_root', default='../../data')
@click.option('--path_experiment_root', default='../../results/temporary')
@click.option('--model_segm', default='unet_lext')
@click.option('--center_depth', default=1, type=int)
@click.option('--model_discr', default='discriminator_a')
@click.option('--pretrained', is_flag=True)
@click.option('--path_pretrained_segm', type=str, help='Path to .pth file')
@click.option('--restore_weights', is_flag=True)
@click.option('--input_channels', default=1, type=int)
@click.option('--output_channels', default=1, type=int)
@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
@click.option('--sample_mode', default='x_y', type=str)
@click.option('--loss_segm', default='multi_ce_loss')
@click.option('--lr_segm', default=0.0001, type=float)
@click.option('--lr_discr', default=0.0001, type=float)
@click.option('--wd_segm', default=5e-5, type=float)
@click.option('--wd_discr', default=5e-5, type=float)
@click.option('--optimizer_segm', default='adam')
@click.option('--optimizer_discr', default='adam')
@click.option('--batch_size', default=64, type=int)
@click.option('--epoch_size', default=1.0, type=float)
@click.option('--epoch_num', default=2, type=int)
@click.option('--fold_num', default=5, type=int)
@click.option('--fold_idx', default=-1, type=int)
@click.option('--fold_idx_ignore', multiple=True, type=int)
@click.option('--num_workers', default=1, type=int)
@click.option('--seed_trainval_test', default=0, type=int)
@click.option('--with_mixup', is_flag=True)
@click.option('--mixup_alpha', default=1, type=float)
def main(**config):
    """Entry point: train the adversarial segmentation model fold-by-fold."""
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train')
    os.makedirs(config['path_weights'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx'])))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=('oai_imo', 'okoa', 'maknee'),
                                with_folds=True,
                                fold_num=config['fold_num'],
                                seed_trainval_test=config['seed_trainval_test'])

    # Build a list of folds to run on
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx'], ]
    for g in config['fold_idx_ignore']:
        fold_idcs = [i for i in fold_idcs if i != g]

    # Train each fold separately
    fold_scores = dict()

    # Use straightforward fold allocation strategy
    source_names = ('oai_imo', 'okoa', 'maknee')
    folds = list(zip(*(sources[name]['trainval_folds']
                       for name in source_names)))

    # Dataset class and mask mode per source; OAI iMo is the annotated source
    # domain, OKOA is used for cross-dataset validation metrics, and MAKNEE
    # is the unannotated target domain.
    dataset_specs = {
        'oai_imo': (DatasetOAIiMoSagittal2d, config['mask_mode']),
        'okoa': (DatasetOKOASagittal2d, 'background_femoral_unitibial'),
        'maknee': (DatasetMAKNEESagittal2d, ''),
    }

    for fold_idx, idcs_subsets in enumerate(folds):
        if fold_idx not in fold_idcs:
            continue
        logger.info(f'Training fold {fold_idx}')

        # Distribute the fold's train/val indices and slice the metadata
        for source_idx, name in enumerate(source_names):
            train_idcs, val_idcs = idcs_subsets[source_idx]
            sources[name]['train_idcs'] = train_idcs
            sources[name]['val_idcs'] = val_idcs
            sources[name]['train_df'] = \
                sources[name]['trainval_df'].iloc[train_idcs]
            sources[name]['val_df'] = \
                sources[name]['trainval_df'].iloc[val_idcs]

        for n, s in sources.items():
            logger.info('Made {} train-val split, number of samples: {}, {}'
                        .format(n, len(s['train_df']), len(s['val_df'])))

        # Build datasets from the shared transform pipelines
        datasets = defaultdict(dict)
        for name, (dataset_cls, mask_mode) in dataset_specs.items():
            datasets[name]['train'] = dataset_cls(
                df_meta=sources[name]['train_df'],
                mask_mode=mask_mode,
                sample_mode=config['sample_mode'],
                transforms=_transforms_train())
            datasets[name]['val'] = dataset_cls(
                df_meta=sources[name]['val_df'],
                mask_mode=mask_mode,
                sample_mode=config['sample_mode'],
                transforms=_transforms_val())

        # Half-size batches: each step draws one source and one target batch
        loaders = defaultdict(dict)
        for name in dataset_specs:
            for subset, shuffle in (('train', True), ('val', False)):
                loaders[name][subset] = DataLoader(
                    datasets[name][subset],
                    batch_size=int(config['batch_size'] / 2),
                    shuffle=shuffle,
                    num_workers=config['num_workers'],
                    drop_last=True)

        trainer = ModelTrainer(config=config, fold_idx=fold_idx)

        tmp = trainer.fit(loaders=loaders)
        metrics_train, fnames_train, metrics_val, fnames_val = tmp

        fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'],
                                 metrics_val['datasetw']['dice_score_okoa'])
        trainer.tensorboard.close()

    logger.info(f'Fold scores:\n{repr(fold_scores)}')


if __name__ == '__main__':
    main()