--- /dev/null
+++ b/rocaseg/evaluate.py
@@ -0,0 +1,398 @@
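+r"""Evaluate trained segmentation models: run fold-wise inference, merge the
+per-fold predictions into ensembled masks, and optionally export
+visualizations and 3D NIfTI volumes.
+
+A hypothetical invocation (argument values are illustrative only):
+
+    python evaluate.py \
+        --path_data_root ../../data \
+        --path_experiment_root ../../results/temporary \
+        --dataset okoa --subset test \
+        --output_channels 5 \
+        --predict_folds --merge_predictions --save_plots
+"""
+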
+import os
+import logging
+from glob import glob
+
+import numpy as np
+from skimage.color import label2rgb
+from skimage import img_as_ubyte
+from tqdm import tqdm
+import click
+
+import cv2
+import tifffile
+import torch
+import torch.nn as nn
+from torch.utils.data.dataloader import DataLoader
+
+from rocaseg.datasets import sources_from_path
+from rocaseg.components import CheckpointHandler
+from rocaseg.components.formats import numpy_to_nifti, png_to_numpy
+from rocaseg.models import dict_models
+from rocaseg.preproc import (CenterCrop, Normalize, UnNormalize,
+                             PercentileClippingAndToFloat, ToTensor)
+from rocaseg.repro import set_ultimate_seed
+
+
+# Workaround for the PyTorch multiprocessing issue:
+# "RuntimeError: received 0 items of ancdata"
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+
+logging.basicConfig()
+logger = logging.getLogger('eval')
+logger.setLevel(logging.INFO)
+
+set_ultimate_seed()
+
+maybe_gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def predict_folds(config, loader, fold_idcs):
+    """Evaluate the model versus each fold
+    """
+    for fold_idx in fold_idcs:
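+        # Locate this fold's weights directory and select the latest
+        # saved segmentation checkpoint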
+        paths_weights_fold = dict()
+        paths_weights_fold['segm'] = \
+            os.path.join(config['path_weights'], 'segm', f'fold_{fold_idx}')
+
+        handlers_ckpt = dict()
+        handlers_ckpt['segm'] = CheckpointHandler(paths_weights_fold['segm'])
+
+        paths_ckpt_sel = dict()
+        paths_ckpt_sel['segm'] = handlers_ckpt['segm'].get_last_ckpt()
+
+        # Initialize and configure the model
+        model_cls = dict_models[config['model_segm']]
+        model = model_cls(input_channels=config['input_channels'],
+                          output_channels=config['output_channels'],
+                          center_depth=config['center_depth'],
+                          pretrained=config['pretrained'],
+                          restore_weights=config['restore_weights'],
+                          path_weights=paths_ckpt_sel['segm'])
+        model = nn.DataParallel(model).to(maybe_gpu)
+        model.eval()
+
+        with tqdm(total=len(loader), desc=f'Eval, fold {fold_idx}') as prog_bar:
+            for i, data_batch in enumerate(loader):
+                xs, ys_true = data_batch['xs'], data_batch['ys']
+                xs, ys_true = xs.to(maybe_gpu), ys_true.to(maybe_gpu)
+
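+                # The auxiliary model variant returns a tuple; only the
+                # segmentation output is used here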
+                if config['model_segm'] == 'unet_lext':
+                    ys_pred = model(xs)
+                elif config['model_segm'] == 'unet_lext_aux':
+                    ys_pred, _ = model(xs)
+                else:
+                    msg = f"Unknown model {config['model_segm']}"
+                    raise ValueError(msg)
+
+                ys_pred_softmax = nn.Softmax(dim=1)(ys_pred)
+                ys_pred_softmax_np = ys_pred_softmax.detach().to('cpu').numpy()
+
+                data_batch['pred_softmax'] = ys_pred_softmax_np
+
+                # Rearrange the batch
+                data_dicts = [{k: v[n] for k, v in data_batch.items()}
+                              for n in range(len(data_batch['image']))]
+
+                for k, data_dict in enumerate(data_dicts):
+                    dir_base = os.path.join(
+                        config['path_predicts'],
+                        data_dict['patient'], data_dict['release'], data_dict['sequence'])
+                    fname_base = os.path.splitext(
+                        os.path.basename(data_dict['path_rel_image']))[0]
+
+                    # Save the predictions
+                    dir_predicts = os.path.join(dir_base, 'mask_folds')
+                    os.makedirs(dir_predicts, exist_ok=True)
+
+                    fname_full = os.path.join(
+                        dir_predicts,
+                        f'{fname_base}_fold_{fold_idx}.tiff')
+
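+                    # Quantize the softmax probabilities to uint8
+                    # ([0, 1] -> [0, 255]) to keep the per-fold TIFFs compact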
+                    tmp = (data_dict['pred_softmax'] * 255).astype(np.uint8, casting='unsafe')
+                    tifffile.imsave(fname_full, tmp, compress=9)
+
+                prog_bar.update(1)
+
+
+def merge_predictions(config, source, loader, dict_fns,
+                      save_plots=False, remove_foldw=False, convert_to_nifti=True):
+    """Merge the predictions over all folds
+    """
+    dir_source_root = source['path_root']
+    df_meta = loader.dataset.df_meta
+
+    with tqdm(total=len(df_meta), desc='Merge') as prog_bar:
+        for i, row in df_meta.iterrows():
+            dir_scan_predicts = os.path.join(
+                config['path_predicts'],
+                row['patient'], row['release'], row['sequence'])
+            dir_image_prep = os.path.join(dir_scan_predicts, 'image_prep')
+            dir_mask_prep = os.path.join(dir_scan_predicts, 'mask_prep')
+            dir_mask_folds = os.path.join(dir_scan_predicts, 'mask_folds')
+            dir_mask_foldavg = os.path.join(dir_scan_predicts, 'mask_foldavg')
+            dir_vis_foldavg = os.path.join(dir_scan_predicts, 'vis_foldavg')
+
+            for p in (dir_image_prep, dir_mask_prep, dir_mask_folds, dir_mask_foldavg,
+                      dir_vis_foldavg):
+                os.makedirs(p, exist_ok=True)
+
+            # Find the corresponding prediction files
+            fname_base = os.path.splitext(os.path.basename(row['path_rel_image']))[0]
+
+            fnames_pred = glob(os.path.join(dir_mask_folds, f'{fname_base}_fold_*.*'))
+
+            # Read the reference data
+            image = cv2.imread(
+                os.path.join(dir_source_root, row['path_rel_image']),
+                cv2.IMREAD_GRAYSCALE)
+            image = dict_fns['crop'](image[None, ])[0]
+            image = np.squeeze(image)
+            if 'path_rel_mask' in row.index:
+                ys_true = loader.dataset.read_mask(
+                    os.path.join(dir_source_root, row['path_rel_mask']))
+                if ys_true is not None:
+                    ys_true = dict_fns['crop'](ys_true)[0]
+            else:
+                ys_true = None
+
+            # Read the fold-wise predictions
+            yss_pred = [tifffile.imread(f) for f in fnames_pred]
+            ys_pred = np.stack(yss_pred, axis=0).astype(np.float32) / 255
+            ys_pred = torch.from_numpy(ys_pred).unsqueeze(dim=0)
+
+            # Average the fold predictions
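+            # (1, num_folds, num_classes, H, W) -> (1, num_classes, H, W),
+            # then renormalize so the class scores sum to 1 per pixel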
+            ys_pred = torch.mean(ys_pred, dim=1, keepdim=False)
+            ys_pred_softmax = ys_pred / torch.sum(ys_pred, dim=1, keepdim=True)
+            ys_pred_softmax_np = ys_pred_softmax.squeeze().numpy()
+
+            ys_pred_arg_np = ys_pred_softmax_np.argmax(axis=0)
+
+            # Save preprocessed input data
+            fname_full = os.path.join(dir_image_prep, f'{fname_base}.png')
+            cv2.imwrite(fname_full, image)  # image
+
+            if ys_true is not None:
+                ys_true = ys_true.astype(np.float32)
+                ys_true = torch.from_numpy(ys_true).unsqueeze(dim=0)
+                ys_true_arg_np = ys_true.numpy().squeeze().argmax(axis=0)
+                fname_full = os.path.join(dir_mask_prep, f'{fname_base}.png')
+                cv2.imwrite(fname_full, ys_true_arg_np)  # mask
+
+            fname_meta = os.path.join(config['path_predicts'], 'meta_dynamic.csv')
+            if not os.path.exists(fname_meta):
+                df_meta.to_csv(fname_meta, index=False)  # metainfo
+
+            # Save ensemble prediction
+            fname_full = os.path.join(dir_mask_foldavg, f'{fname_base}.png')
+            cv2.imwrite(fname_full, ys_pred_arg_np)
+
+            # Save ensemble visualizations
+            if save_plots:
+                if ys_true is not None:
+                    fname_full = os.path.join(
+                        dir_vis_foldavg, f"{fname_base}_overlay_mask.png")
+                    save_vis_overlay(image=image,
+                                     mask=ys_true_arg_np,
+                                     num_classes=config['output_channels'],
+                                     fname=fname_full)
+
+                fname_full = os.path.join(
+                    dir_vis_foldavg, f"{fname_base}_overlay_pred.png")
+                save_vis_overlay(image=image,
+                                 mask=ys_pred_arg_np,
+                                 num_classes=config['output_channels'],
+                                 fname=fname_full)
+
+                if ys_true is not None:
+                    fname_full = os.path.join(
+                        dir_vis_foldavg, f"{fname_base}_overlay_diff.png")
+                    save_vis_mask_diff(image=image,
+                                       mask_true=ys_true_arg_np,
+                                       mask_pred=ys_pred_arg_np,
+                                       fname=fname_full)
+
+            # Remove the fold predictions
+            if remove_foldw:
+                for f in fnames_pred:
+                    try:
+                        os.remove(f)
+                    except OSError:
+                        logger.error(f'Cannot remove {f}')
+            prog_bar.update(1)
+
+    # Convert the results to 3D NIfTI images
+    if convert_to_nifti:
+        df_meta = df_meta.sort_values(by=["patient", "release", "sequence", "side"])
+
+        for gb_name, gb_df in tqdm(
+                df_meta.groupby(["patient", "release", "sequence", "side"]),
+                desc="Convert to NIfTI"):
+
+            patient, release, sequence, side = gb_name
+            spacings = (gb_df['pixel_spacing_0'].iloc[0],
+                        gb_df['pixel_spacing_1'].iloc[0],
+                        gb_df['slice_thickness'].iloc[0])
+
+            dir_scan_predicts = os.path.join(config['path_predicts'],
+                                             patient, release, sequence)
+            for result in ("image_prep", "mask_prep", "mask_foldavg"):
+                pattern = os.path.join(dir_scan_predicts, result, '*.png')
+                path_nii = os.path.join(dir_scan_predicts, f"{result}.nii")
+
+                # Read and compose 3D image
+                img = png_to_numpy(pattern_fname_in=pattern, reverse=False)
+
+                # Save to NIfTI
+                numpy_to_nifti(stack=img, fname_out=path_nii,
+                               spacings=spacings, rcp_to_ras=True)
+
+
+def save_vis_overlay(image, mask, num_classes, fname):
+    # Work on a copy; add a sample of each class to have consistent
+    # class colors without mutating the caller's array
+    mask = mask.copy()
+    mask[0, :num_classes] = list(range(num_classes))
+    overlay = label2rgb(label=mask, image=image, bg_label=0,
+                        colors=['orangered', 'gold', 'lime', 'fuchsia'])
+    # Convert to uint8 to save space
+    overlay = img_as_ubyte(overlay)
+    # label2rgb returns RGB; flip to BGR for cv2.imwrite
+    if overlay.ndim == 3:
+        overlay = overlay[:, :, ::-1]
+    cv2.imwrite(fname, overlay)
+
+
+def save_vis_mask_diff(image, mask_true, mask_pred, fname):
+    diff = np.empty_like(mask_true)
+    diff[(mask_true == mask_pred) & (mask_pred == 0)] = 0  # TN
+    diff[(mask_true == mask_pred) & (mask_pred != 0)] = 0  # TP
+    diff[(mask_true != mask_pred) & (mask_pred == 0)] = 2  # FP
+    diff[(mask_true != mask_pred) & (mask_pred != 0)] = 3  # FN
+    diff_colors = ('green', 'red', 'yellow')
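+    # Anchor one pixel per value so the color mapping is stable across
+    # images (same trick as in save_vis_overlay)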
+    diff[0, :4] = [0, 1, 2, 3]
+    overlay = label2rgb(label=diff, image=image, bg_label=0,
+                        colors=diff_colors)
+    # Convert to uint8 to save space
+    overlay = img_as_ubyte(overlay)
+    # label2rgb returns RGB; flip to BGR for cv2.imwrite
+    if overlay.ndim == 3:
+        overlay = overlay[:, :, ::-1]
+    cv2.imwrite(fname, overlay)
+
+
+@click.command()
+@click.option('--path_data_root', default='../../data')
+@click.option('--path_experiment_root', default='../../results/temporary')
+@click.option('--model_segm', default='unet_lext')
+@click.option('--center_depth', default=1, type=int)
+@click.option('--pretrained', is_flag=True)
+@click.option('--restore_weights', is_flag=True)
+@click.option('--input_channels', default=1, type=int)
+@click.option('--output_channels', default=1, type=int)
+@click.option('--dataset', type=click.Choice(
+    ['oai_imo', 'okoa', 'maknee']))
+@click.option('--subset', type=click.Choice(
+    ['test', 'all']))
+@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
+@click.option('--sample_mode', default='x_y', type=str)
+@click.option('--batch_size', default=64, type=int)
+@click.option('--fold_num', default=5, type=int)
+@click.option('--fold_idx', default=-1, type=int)
+@click.option('--fold_idx_ignore', multiple=True, type=int)
+@click.option('--num_workers', default=1, type=int)
+@click.option('--seed_trainval_test', default=0, type=int)
+@click.option('--predict_folds', is_flag=True)
+@click.option('--merge_predictions', is_flag=True)
+@click.option('--save_plots', is_flag=True)
+def main(**config):
+    config['path_data_root'] = os.path.abspath(config['path_data_root'])
+    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
+
+    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
+    if not os.path.exists(config['path_weights']):
+        raise ValueError(f"{config['path_weights']} does not exist")
+
+    config['path_predicts'] = os.path.join(
+        config['path_experiment_root'], f"predicts_{config['dataset']}_test")
+    config['path_logs'] = os.path.join(
+        config['path_experiment_root'], f"logs_{config['dataset']}_test")
+
+    os.makedirs(config['path_predicts'], exist_ok=True)
+    os.makedirs(config['path_logs'], exist_ok=True)
+
+    logging_fh = logging.FileHandler(
+        os.path.join(config['path_logs'], 'main.log'))
+    logging_fh.setLevel(logging.DEBUG)
+    logger.addHandler(logging_fh)
+
+    # Collect the available and specified sources
+    sources = sources_from_path(path_data_root=config['path_data_root'],
+                                selection=config['dataset'],
+                                with_folds=True,
+                                seed_trainval_test=config['seed_trainval_test'])
+
+    # Select the subset for evaluation
+    if config['subset'] == 'test':
+        logger.warning('Using the regular trainval-test split')
+    elif config['subset'] == 'all':
+        logger.warning('Using data selection: full dataset')
+        for s in sources:
+            sources[s]['test_df'] = sources[s]['sel_df']
+            logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
+    else:
+        raise ValueError(f"Unknown subset: {config['subset']}")
+
+    if config['dataset'] == 'oai_imo':
+        from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
+    elif config['dataset'] == 'okoa':
+        from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
+    elif config['dataset'] == 'maknee':
+        from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
+    else:
+        raise ValueError(f"Unknown dataset: {config['dataset']}")
+
+    # Configure dataset-dependent transforms
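+    # The mean/std constants are dataset-specific intensity statistics,
+    # presumably precomputed on the corresponding training data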
+    fn_crop = CenterCrop(height=300, width=300)
+    if config['dataset'] == 'oai_imo':
+        fn_norm = Normalize(mean=0.252699, std=0.251142)
+        fn_unnorm = UnNormalize(mean=0.252699, std=0.251142)
+    elif config['dataset'] == 'okoa':
+        fn_norm = Normalize(mean=0.232454, std=0.236259)
+        fn_unnorm = UnNormalize(mean=0.232454, std=0.236259)
+    else:
+        msg = f"No transforms defined for dataset: {config['dataset']}"
+        raise NotImplementedError(msg)
+    dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm}
+
+    dataset_test = DatasetSagittal2d(
+        df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'],
+        name=config['dataset'], sample_mode=config['sample_mode'],
+        transforms=[
+            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
+            fn_crop,
+            fn_norm,
+            ToTensor()
+        ])
+    loader_test = DataLoader(dataset_test,
+                             batch_size=config['batch_size'],
+                             shuffle=False,
+                             num_workers=config['num_workers'],
+                             drop_last=False)
+
+    # Build a list of folds to run on
+    if config['fold_idx'] == -1:
+        fold_idcs = list(range(config['fold_num']))
+    else:
+        fold_idcs = [config['fold_idx']]
+    for g in config['fold_idx_ignore']:
+        fold_idcs = [i for i in fold_idcs if i != g]
+
+    # Execute
+    with torch.no_grad():
+        if config['predict_folds']:
+            predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs)
+
+        if config['merge_predictions']:
+            merge_predictions(config=config, source=sources[config['dataset']],
+                              loader=loader_test, dict_fns=dict_fns,
+                              save_plots=config['save_plots'], remove_foldw=False,
+                              convert_to_nifti=True)
+
+
+if __name__ == '__main__':
+    main()