Diff of /rocaseg/evaluate.py [000000] .. [6969be]

Switch to unified view

a b/rocaseg/evaluate.py
1
import os
2
import logging
3
from glob import glob
4
5
import numpy as np
6
from skimage.color import label2rgb
7
from skimage import img_as_ubyte
8
from tqdm import tqdm
9
import click
10
11
import cv2
12
import tifffile
13
import torch
14
import torch.nn as nn
15
from torch.utils.data.dataloader import DataLoader
16
17
from rocaseg.datasets import sources_from_path
18
from rocaseg.components import CheckpointHandler
19
from rocaseg.components.formats import numpy_to_nifti, png_to_numpy
20
from rocaseg.models import dict_models
21
from rocaseg.preproc import *
22
from rocaseg.repro import set_ultimate_seed
23
24
25
# Workaround for the PyTorch multiprocessing failure
# "RuntimeError: received 0 items of ancdata": switch the tensor sharing
# strategy to 'file_system'.
torch.multiprocessing.set_sharing_strategy('file_system')

# Disable OpenCL and OpenCV's internal threading
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

# Module logger; INFO and above are emitted
logging.basicConfig()
logger = logging.getLogger('eval')
logger.setLevel(logging.INFO)

set_ultimate_seed()

# Device used for inference: CUDA when available, CPU otherwise
maybe_gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
42
43
44
def predict_folds(config, loader, fold_idcs):
    """Run inference with the model of each fold and save the soft predictions.

    For every fold index the checkpoint returned by
    ``CheckpointHandler(<path_weights>/segm/fold_<idx>).get_last_ckpt()`` is
    used to build the model, the model is applied to all batches of `loader`,
    and the softmax output of every sample is stored as an 8-bit TIFF in
    ``<path_predicts>/<patient>/<release>/<sequence>/mask_folds/
    <image name>_fold_<idx>.tiff``.

    No gradient context is set up here; `main` calls this function inside
    ``torch.no_grad()``.

    Args:
        config: dict of run settings. Keys read: 'path_weights',
            'path_predicts', 'model_segm', 'input_channels',
            'output_channels', 'center_depth', 'pretrained',
            'restore_weights'.
        loader: sized iterable of batches. Each batch is a dict that provides
            at least 'xs', 'image', 'patient', 'release', 'sequence' and
            'path_rel_image'; every value must be indexable by sample.
        fold_idcs: iterable of fold indices to evaluate.

    Raises:
        ValueError: if ``config['model_segm']`` is neither 'unet_lext' nor
            'unet_lext_aux' (raised when the first batch is processed).
    """
    for fold_idx in fold_idcs:
        path_weights_fold = os.path.join(
            config['path_weights'], 'segm', f'fold_{fold_idx}')
        path_ckpt_sel = CheckpointHandler(path_weights_fold).get_last_ckpt()

        # Initialize and configure the model
        model = dict_models[config['model_segm']](
            input_channels=config['input_channels'],
            output_channels=config['output_channels'],
            center_depth=config['center_depth'],
            pretrained=config['pretrained'],
            restore_weights=config['restore_weights'],
            path_weights=path_ckpt_sel)
        model = nn.DataParallel(model).to(maybe_gpu)
        model.eval()

        with tqdm(total=len(loader), desc=f'Eval, fold {fold_idx}') as prog_bar:
            for data_batch in loader:
                # Only the inputs are needed for inference; the reference
                # masks are neither read nor transferred to the device.
                xs = data_batch['xs'].to(maybe_gpu)

                if config['model_segm'] == 'unet_lext':
                    ys_pred = model(xs)
                elif config['model_segm'] == 'unet_lext_aux':
                    # The second output of this model is not used here
                    ys_pred, _ = model(xs)
                else:
                    msg = f"Unknown model {config['model_segm']}"
                    raise ValueError(msg)

                ys_pred_softmax = nn.Softmax(dim=1)(ys_pred)
                data_batch['pred_softmax'] = \
                    ys_pred_softmax.detach().to('cpu').numpy()

                # Rearrange the batch: dict of batched values -> one dict
                # per sample
                data_dicts = [{k: v[n] for k, v in data_batch.items()}
                              for n in range(len(data_batch['image']))]

                for data_dict in data_dicts:
                    dir_base = os.path.join(
                        config['path_predicts'],
                        data_dict['patient'], data_dict['release'],
                        data_dict['sequence'])
                    fname_base = os.path.splitext(
                        os.path.basename(data_dict['path_rel_image']))[0]

                    # Save the predictions
                    dir_predicts = os.path.join(dir_base, 'mask_folds')
                    os.makedirs(dir_predicts, exist_ok=True)

                    fname_full = os.path.join(
                        dir_predicts,
                        f'{fname_base}_fold_{fold_idx}.tiff')

                    # Scale the probabilities to the 8-bit range; the cast
                    # to uint8 truncates the fractional part
                    tmp = (data_dict['pred_softmax'] * 255).astype(
                        np.uint8, casting='unsafe')
                    tifffile.imsave(fname_full, tmp, compress=9)

                prog_bar.update(1)
111
112
113
def merge_predictions(config, source, loader, dict_fns,
                      save_plots=False, remove_foldw=False, convert_to_nifti=True):
    """Average the fold-wise predictions of every slice and export the results.

    The metadata table ``loader.dataset.df_meta`` is written once to
    ``<path_predicts>/meta_dynamic.csv`` (an existing file is kept). Then, for
    each row of the table:

    * the source image is read in grayscale and cropped with
      ``dict_fns['crop']``; the reference mask is read via
      ``loader.dataset.read_mask`` when the table has a 'path_rel_mask'
      column;
    * all files matching ``mask_folds/<image name>_fold_*.*`` are read,
      rescaled by 1/255, averaged over the folds, renormalized over the class
      axis, and reduced to a label map with argmax;
    * the cropped image, the reference label map (if any) and the ensemble
      label map are saved as PNG into ``image_prep``, ``mask_prep`` and
      ``mask_foldavg`` of the scan directory.

    Args:
        config: dict of run settings. Keys read: 'path_predicts' and
            'output_channels'.
        source: dict with key 'path_root', the root directory that the
            relative image/mask paths of the metadata refer to.
        loader: object whose ``dataset`` exposes ``df_meta`` and
            ``read_mask``.
        dict_fns: dict with the key 'crop' mapping to the cropping callable.
        save_plots: if True, save overlay visualizations into
            ``vis_foldavg``.
        remove_foldw: if True, delete the fold-wise prediction files after
            merging; failures to delete are logged and otherwise ignored.
        convert_to_nifti: if True, stack the PNGs of every
            (patient, release, sequence, side) group into ``image_prep.nii``,
            ``mask_prep.nii`` and ``mask_foldavg.nii``.

    Raises:
        FileNotFoundError: if the source image of a slice cannot be read.
        ValueError: if no fold-wise prediction is found for a slice.
    """
    dir_source_root = source['path_root']
    df_meta = loader.dataset.df_meta

    # The metadata does not depend on the slice: write it once and never
    # overwrite an existing file
    os.makedirs(config['path_predicts'], exist_ok=True)
    fname_meta = os.path.join(config['path_predicts'], 'meta_dynamic.csv')
    if not os.path.exists(fname_meta):
        df_meta.to_csv(fname_meta, index=False)  # metainfo

    with tqdm(total=len(df_meta), desc='Merge') as prog_bar:
        for _, row in df_meta.iterrows():
            dir_scan_predicts = os.path.join(
                config['path_predicts'],
                row['patient'], row['release'], row['sequence'])
            dir_image_prep = os.path.join(dir_scan_predicts, 'image_prep')
            dir_mask_prep = os.path.join(dir_scan_predicts, 'mask_prep')
            dir_mask_folds = os.path.join(dir_scan_predicts, 'mask_folds')
            dir_mask_foldavg = os.path.join(dir_scan_predicts, 'mask_foldavg')
            dir_vis_foldavg = os.path.join(dir_scan_predicts, 'vis_foldavg')

            for p in (dir_image_prep, dir_mask_prep, dir_mask_folds,
                      dir_mask_foldavg, dir_vis_foldavg):
                os.makedirs(p, exist_ok=True)

            # Find the corresponding prediction files
            fname_base = os.path.splitext(
                os.path.basename(row['path_rel_image']))[0]
            fnames_pred = glob(
                os.path.join(dir_mask_folds, f'{fname_base}_fold_*.*'))
            if not fnames_pred:
                # Fail with a message that names the slice instead of the
                # opaque error np.stack raises on an empty list
                raise ValueError(
                    f'No fold-wise predictions found for {fname_base} '
                    f'in {dir_mask_folds}')

            # Read the reference data
            path_image = os.path.join(dir_source_root, row['path_rel_image'])
            image = cv2.imread(path_image, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread signals a failed read by returning None
                raise FileNotFoundError(f'Cannot read image {path_image}')
            # The crop is applied to the image with an added leading axis,
            # which is dropped again afterwards
            image = dict_fns['crop'](image[None, ])[0]
            image = np.squeeze(image)

            if 'path_rel_mask' in row.index:
                ys_true = loader.dataset.read_mask(
                    os.path.join(dir_source_root, row['path_rel_mask']))
                if ys_true is not None:
                    ys_true = dict_fns['crop'](ys_true)[0]
            else:
                ys_true = None

            # Read the fold-wise predictions
            yss_pred = [tifffile.imread(f) for f in fnames_pred]
            ys_pred = np.stack(yss_pred, axis=0).astype(np.float32) / 255
            ys_pred = torch.from_numpy(ys_pred).unsqueeze(dim=0)

            # Average the fold predictions, then renormalize over dim 1 so
            # that the averaged scores sum to one again
            ys_pred = torch.mean(ys_pred, dim=1, keepdim=False)
            ys_pred_softmax = ys_pred / torch.sum(ys_pred, dim=1, keepdim=True)
            ys_pred_softmax_np = ys_pred_softmax.squeeze().numpy()

            ys_pred_arg_np = ys_pred_softmax_np.argmax(axis=0)

            # Save preprocessed input data
            fname_full = os.path.join(dir_image_prep, f'{fname_base}.png')
            cv2.imwrite(fname_full, image)  # image

            if ys_true is not None:
                ys_true_arg_np = np.squeeze(
                    ys_true.astype(np.float32)).argmax(axis=0)
                fname_full = os.path.join(dir_mask_prep, f'{fname_base}.png')
                cv2.imwrite(fname_full, ys_true_arg_np)  # mask

            # Save ensemble prediction
            fname_full = os.path.join(dir_mask_foldavg, f'{fname_base}.png')
            cv2.imwrite(fname_full, ys_pred_arg_np)

            # Save ensemble visualizations
            if save_plots:
                if ys_true is not None:
                    fname_full = os.path.join(
                        dir_vis_foldavg, f"{fname_base}_overlay_mask.png")
                    save_vis_overlay(image=image,
                                     mask=ys_true_arg_np,
                                     num_classes=config['output_channels'],
                                     fname=fname_full)

                fname_full = os.path.join(
                    dir_vis_foldavg, f"{fname_base}_overlay_pred.png")
                save_vis_overlay(image=image,
                                 mask=ys_pred_arg_np,
                                 num_classes=config['output_channels'],
                                 fname=fname_full)

                if ys_true is not None:
                    fname_full = os.path.join(
                        dir_vis_foldavg, f"{fname_base}_overlay_diff.png")
                    save_vis_mask_diff(image=image,
                                       mask_true=ys_true_arg_np,
                                       mask_pred=ys_pred_arg_np,
                                       fname=fname_full)

            # Remove the fold predictions (best effort)
            if remove_foldw:
                for f in fnames_pred:
                    try:
                        os.remove(f)
                    except OSError:
                        logger.error(f'Cannot remove {f}')
            prog_bar.update(1)

    # Convert the results to 3D NIfTI images
    if convert_to_nifti:
        group_keys = ["patient", "release", "sequence", "side"]
        df_meta = df_meta.sort_values(by=group_keys)

        for gb_name, gb_df in tqdm(df_meta.groupby(group_keys),
                                   desc="Convert to NIfTI"):
            patient, release, sequence, side = gb_name
            # Spacings are taken from the first row of the group
            spacings = (gb_df['pixel_spacing_0'].iloc[0],
                        gb_df['pixel_spacing_1'].iloc[0],
                        gb_df['slice_thickness'].iloc[0])

            dir_scan_predicts = os.path.join(config['path_predicts'],
                                             patient, release, sequence)
            for result in ("image_prep", "mask_prep", "mask_foldavg"):
                pattern = os.path.join(dir_scan_predicts, result, '*.png')
                path_nii = os.path.join(dir_scan_predicts, f"{result}.nii")

                # Read and compose 3D image
                img = png_to_numpy(pattern_fname_in=pattern, reverse=False)

                # Save to NIfTI
                numpy_to_nifti(stack=img, fname_out=path_nii,
                               spacings=spacings, rcp_to_ras=True)
245
246
247
def save_vis_overlay(image, mask, num_classes, fname):
    """Save `image` with the labels of `mask` overlaid in color.

    The labels are rendered with ``label2rgb`` (label 0 is the background)
    using the colors 'orangered', 'gold', 'lime' and 'fuchsia', and the
    result is written to `fname` with ``cv2.imwrite``.

    Args:
        image: image passed to ``label2rgb`` as the underlay.
        mask: label array indexable as ``mask[row, col]``. It is not
            modified; the rendering works on a copy.
        num_classes: number of classes; the labels ``0..num_classes-1`` are
            written into the first pixels of the top row of the copy.
        fname: path of the output image file.
    """
    # Work on a copy so that the sample labels written below do not leak
    # into the caller's mask
    mask = np.array(mask, copy=True)

    # Add a sample of each class to have consistent class colors
    mask[0, :num_classes] = list(range(num_classes))
    overlay = label2rgb(label=mask, image=image, bg_label=0,
                        colors=['orangered', 'gold', 'lime', 'fuchsia'])
    # Convert to uint8 to save space
    overlay = img_as_ubyte(overlay)
    # Save to file; a 3-channel overlay gets its channel order reversed
    # before it is handed to OpenCV
    if overlay.ndim == 3:
        overlay = overlay[:, :, ::-1]
    cv2.imwrite(fname, overlay)
258
259
260
def save_vis_mask_diff(image, mask_true, mask_pred, fname):
    """Save `image` with the disagreement between two label maps overlaid.

    Pixels are encoded as follows before rendering with ``label2rgb``:

    * 0 - `mask_true` and `mask_pred` agree (background or foreground);
    * 2 - the masks differ and the prediction is background (label 0);
    * 3 - the masks differ and the prediction is a foreground label.

    The first four pixels of the top row are overwritten with the codes
    0, 1, 2, 3 so that every code is present in the rendered map.

    Args:
        image: image passed to ``label2rgb`` as the underlay.
        mask_true: reference label array.
        mask_pred: predicted label array, compared element-wise with
            `mask_true`.
        fname: path of the output image file.
    """
    is_match = mask_true == mask_pred
    is_pred_bg = mask_pred == 0

    codes = np.empty_like(mask_true)
    codes[is_match] = 0                  # agreement, left unmarked
    codes[~is_match & is_pred_bg] = 2    # mismatch, predicted background
    codes[~is_match & ~is_pred_bg] = 3   # mismatch, predicted foreground
    codes[0, :4] = [0, 1, 2, 3]

    palette = ('green', 'red', 'yellow')
    rendered = label2rgb(label=codes, image=image, bg_label=0,
                         colors=palette)
    # Convert to uint8 to save space
    rendered = img_as_ubyte(rendered)
    # Save to file; a 3-channel overlay gets its channel order reversed
    # before it is handed to OpenCV
    if rendered.ndim == 3:
        rendered = rendered[:, :, ::-1]
    cv2.imwrite(fname, rendered)
276
277
278
@click.command()
@click.option('--path_data_root', default='../../data')
@click.option('--path_experiment_root', default='../../results/temporary')
@click.option('--model_segm', default='unet_lext')
@click.option('--center_depth', default=1, type=int)
@click.option('--pretrained', is_flag=True)
@click.option('--restore_weights', is_flag=True)
@click.option('--input_channels', default=1, type=int)
@click.option('--output_channels', default=1, type=int)
@click.option('--dataset', type=click.Choice(
    ['oai_imo', 'okoa', 'maknee']))
@click.option('--subset',  type=click.Choice(
    ['test', 'all']))
@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
@click.option('--sample_mode', default='x_y', type=str)
@click.option('--batch_size', default=64, type=int)
@click.option('--fold_num', default=5, type=int)
@click.option('--fold_idx', default=-1, type=int)
@click.option('--fold_idx_ignore', multiple=True, type=int)
@click.option('--num_workers', default=1, type=int)
@click.option('--seed_trainval_test', default=0, type=int)
@click.option('--predict_folds', is_flag=True)
@click.option('--merge_predictions', is_flag=True)
@click.option('--save_plots', is_flag=True)
def main(**config):
    """Evaluate trained segmentation models on the selected dataset.

    With --predict_folds the model of every selected fold is applied to the
    data; with --merge_predictions the fold-wise predictions are averaged
    and exported.
    """
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    if not os.path.exists(config['path_weights']):
        raise ValueError('{} does not exist'.format(config['path_weights']))

    config['path_predicts'] = os.path.join(
        config['path_experiment_root'], f"predicts_{config['dataset']}_test")
    config['path_logs'] = os.path.join(
        config['path_experiment_root'], f"logs_{config['dataset']}_test")

    os.makedirs(config['path_predicts'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    # Mirror the messages of the module logger into a file
    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main.log'))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=config['dataset'],
                                with_folds=True,
                                seed_trainval_test=config['seed_trainval_test'])

    # Select the subset for evaluation. The module logger is used (not the
    # root logger) so that these messages also reach the log file.
    if config['subset'] == 'test':
        logger.warning('Using the regular trainval-test split')
    elif config['subset'] == 'all':
        logger.warning('Using data selection: full dataset')
        for s in sources:
            sources[s]['test_df'] = sources[s]['sel_df']
            logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
    else:
        raise ValueError(f"Unknown subset: {config['subset']}")

    if config['dataset'] == 'oai_imo':
        from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'okoa':
        from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'maknee':
        from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
    else:
        raise ValueError(f"Unknown dataset: {config['dataset']}")

    # Configure dataset-dependent transforms
    fn_crop = CenterCrop(height=300, width=300)
    if config['dataset'] == 'oai_imo':
        norm_mean, norm_std = 0.252699, 0.251142
    elif config['dataset'] == 'okoa':
        norm_mean, norm_std = 0.232454, 0.236259
    else:
        msg = f"No transforms defined for dataset: {config['dataset']}"
        raise NotImplementedError(msg)
    # Normalization and its inverse share the same statistics
    fn_norm = Normalize(mean=norm_mean, std=norm_std)
    fn_unnorm = UnNormalize(mean=norm_mean, std=norm_std)
    dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm}

    dataset_test = DatasetSagittal2d(
        df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'],
        name=config['dataset'], sample_mode=config['sample_mode'],
        transforms=[
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            fn_crop,
            fn_norm,
            ToTensor()
        ])
    loader_test = DataLoader(dataset_test,
                             batch_size=config['batch_size'],
                             shuffle=False,
                             num_workers=config['num_workers'],
                             drop_last=False)

    # Build a list of folds to run on: all folds when --fold_idx is -1,
    # otherwise only the requested one; ignored folds are dropped
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx'], ]
    folds_ignored = set(config['fold_idx_ignore'])
    fold_idcs = [i for i in fold_idcs if i not in folds_ignored]

    # Execute
    with torch.no_grad():
        if config['predict_folds']:
            predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs)

        if config['merge_predictions']:
            merge_predictions(config=config, source=sources[config['dataset']],
                              loader=loader_test, dict_fns=dict_fns,
                              save_plots=config['save_plots'], remove_foldw=False,
                              convert_to_nifti=True)
395
396
397
if __name__ == '__main__':
    # Run the command-line interface when the module is executed as a script
    main()