Diff of /evaluator.py [000000] .. [bb7f56]

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os, time
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import roc_curve, precision_recall_curve

import utils.exp_utils as utils
import utils.model_utils as mutils
import utils.eval_util as eutils
import plotting


class Evaluator():

    def __init__(self, cf, logger, mode='test'):
        """
        :param mode: either 'train', 'val_sampling', 'val_patient' or 'test'. Handles prediction lists of different forms.
        """
        self.cf = cf
        self.logger = logger
        self.mode = mode

        self.plot_dir = self.cf.test_dir if self.mode == "test" else self.cf.plot_dir
        if self.cf.plot_prediction_histograms:
            self.hist_dir = os.path.join(self.plot_dir, 'histograms')
            os.makedirs(self.hist_dir, exist_ok=True)
        if self.cf.plot_stat_curves:
            self.curves_dir = os.path.join(self.plot_dir, 'stat_curves')
            os.makedirs(self.curves_dir, exist_ok=True)

    def eval_losses(self, batch_res_dicts):
        """Average each monitored loss over all batches of the epoch that report it."""
        if hasattr(self.cf, "losses_to_monitor"):
            loss_names = self.cf.losses_to_monitor
        else:
            loss_names = {name for b_res_dict in batch_res_dicts for name in b_res_dict if 'loss' in name}
        self.epoch_losses = {l_name: torch.tensor([b_res_dict[l_name] for b_res_dict in batch_res_dicts if l_name
                                                   in b_res_dict.keys()]).mean().item() for l_name in loss_names}

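    # The box lists consumed below hold one dictionary per box. Illustrative sketch of the two box types
    # (keys as referenced in eval_boxes, example values hypothetical):
    #   detection:    {'box_type': 'det', 'box_coords': <coords array>, 'box_pred_class_id': 1, 'box_score': 0.87}
    #   ground truth: {'box_type': 'gt',  'box_coords': <coords array>, 'box_label': 1}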
    def eval_boxes(self, batch_res_dicts, pid_list):
        """Match detections to ground truth boxes per match_iou and class, and collect the results in self.test_df."""

        df_list_preds = []
        df_list_labels = []
        df_list_class_preds = []
        df_list_pids = []
        df_list_type = []
        df_list_match_iou = []


        if self.mode == 'train' or self.mode == 'val_sampling':
            # one pid per batch element
            # batch_size > 1, with varying patients across batch:
            # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...]
            # -> [results_0, results_1, ..]
            batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts]  # len: nr of batches in epoch
            batch_inst_boxes = [[b_inst_boxes] for whole_batch_boxes in batch_inst_boxes for b_inst_boxes in
                                whole_batch_boxes]
        else:
            # patient processing, one element per batch = one patient.
            # [[results_0, pid_0], [results_1, pid_1], ...] -> [results_0, results_1, ..]
            batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts]

        assert len(batch_inst_boxes) == len(pid_list)

        for match_iou in self.cf.ap_match_ious:
            self.logger.info('evaluating with match_iou: {}'.format(match_iou))
            for cl in list(self.cf.class_dict.keys()):
                for pix, pid in enumerate(pid_list):

                    len_df_list_before_patient = len(df_list_pids)

                    # input of each batch element is a list of boxes, where each box is a dictionary.
                    for bix, b_boxes_list in enumerate(batch_inst_boxes[pix]):

                        b_tar_boxes = np.array([box['box_coords'] for box in b_boxes_list if
                                                (box['box_type'] == 'gt' and box['box_label'] == cl)])
                        b_cand_boxes = np.array([box['box_coords'] for box in b_boxes_list if
                                                 (box['box_type'] == 'det' and
                                                  box['box_pred_class_id'] == cl)])
                        b_cand_scores = np.array([box['box_score'] for box in b_boxes_list if
                                                  (box['box_type'] == 'det' and
                                                   box['box_pred_class_id'] == cl)])

                        # check if predictions and ground truth boxes exist and match them according to match_iou.
                        if 0 not in b_cand_boxes.shape and 0 not in b_tar_boxes.shape:
                            overlaps = mutils.compute_overlaps(b_cand_boxes, b_tar_boxes)
                            match_cand_ixs = np.argwhere(np.max(overlaps, 1) > match_iou)[:, 0]
                            non_match_cand_ixs = np.argwhere(np.max(overlaps, 1) <= match_iou)[:, 0]
                            match_gt_ixs = np.argmax(overlaps[match_cand_ixs, :],
                                                     1) if 0 not in match_cand_ixs.shape else np.array([])
                            non_match_gt_ixs = np.array(
                                [ii for ii in np.arange(b_tar_boxes.shape[0]) if ii not in match_gt_ixs])
                            unique, counts = np.unique(match_gt_ixs, return_counts=True)

                            # check for double assignments, i.e. two predictions having been assigned to the same gt.
                            # according to the COCO-metrics, only one prediction counts as true positive, the rest
                            # counts as false positive. This case is supposed to be avoided by the model itself by,
                            # e.g. using a low enough NMS threshold.
                            if np.any(counts > 1):
                                double_match_gt_ixs = unique[np.argwhere(counts > 1)[:, 0]]
                                keep_max = []
                                double_match_list = []
                                for dg in double_match_gt_ixs:
                                    double_match_cand_ixs = match_cand_ixs[np.argwhere(match_gt_ixs == dg)]
                                    keep_max.append(double_match_cand_ixs[np.argmax(b_cand_scores[double_match_cand_ixs])])
                                    double_match_list += [ii for ii in double_match_cand_ixs]

                                fp_ixs = np.array([ii for ii in match_cand_ixs if
                                                   (ii in double_match_list and ii not in keep_max)])

                                match_cand_ixs = np.array([ii for ii in match_cand_ixs if ii not in fp_ixs])

                                df_list_preds += [ii for ii in b_cand_scores[fp_ixs]]
                                df_list_labels += [0] * fp_ixs.shape[0]
                                df_list_class_preds += [cl] * fp_ixs.shape[0]
                                df_list_pids += [pid] * fp_ixs.shape[0]
                                df_list_type += ['det_fp'] * fp_ixs.shape[0]

                            # matched:
                            if 0 not in match_cand_ixs.shape:
                                df_list_preds += [ii for ii in b_cand_scores[match_cand_ixs]]
                                df_list_labels += [1] * match_cand_ixs.shape[0]
                                df_list_class_preds += [cl] * match_cand_ixs.shape[0]
                                df_list_pids += [pid] * match_cand_ixs.shape[0]
                                df_list_type += ['det_tp'] * match_cand_ixs.shape[0]
                            # rest fp:
                            if 0 not in non_match_cand_ixs.shape:
                                df_list_preds += [ii for ii in b_cand_scores[non_match_cand_ixs]]
                                df_list_labels += [0] * non_match_cand_ixs.shape[0]
                                df_list_class_preds += [cl] * non_match_cand_ixs.shape[0]
                                df_list_pids += [pid] * non_match_cand_ixs.shape[0]
                                df_list_type += ['det_fp'] * non_match_cand_ixs.shape[0]
                            # rest fn:
                            if 0 not in non_match_gt_ixs.shape:
                                df_list_preds += [0] * non_match_gt_ixs.shape[0]
                                df_list_labels += [1] * non_match_gt_ixs.shape[0]
                                df_list_class_preds += [cl] * non_match_gt_ixs.shape[0]
                                df_list_pids += [pid] * non_match_gt_ixs.shape[0]
                                df_list_type += ['det_fn'] * non_match_gt_ixs.shape[0]
                        # only fp:
                        if 0 not in b_cand_boxes.shape and 0 in b_tar_boxes.shape:
                            df_list_preds += [ii for ii in b_cand_scores]
                            df_list_labels += [0] * b_cand_scores.shape[0]
                            df_list_class_preds += [cl] * b_cand_scores.shape[0]
                            df_list_pids += [pid] * b_cand_scores.shape[0]
                            df_list_type += ['det_fp'] * b_cand_scores.shape[0]
                        # only fn:
                        if 0 in b_cand_boxes.shape and 0 not in b_tar_boxes.shape:
                            df_list_preds += [0] * b_tar_boxes.shape[0]
                            df_list_labels += [1] * b_tar_boxes.shape[0]
                            df_list_class_preds += [cl] * b_tar_boxes.shape[0]
                            df_list_pids += [pid] * b_tar_boxes.shape[0]
                            df_list_type += ['det_fn'] * b_tar_boxes.shape[0]

                    # an empty patient with 0 detections needs a patient dummy score, in order not to disappear from
                    # the stats. It is filtered out for roi-level evaluation later. During training (and val_sampling),
                    # tn are assigned per sample independently of associated patients.
                    if len(df_list_pids) == len_df_list_before_patient:
                        df_list_preds += [0] * 1
                        df_list_labels += [0] * 1
                        df_list_class_preds += [cl] * 1
                        df_list_pids += [pid] * 1
                        df_list_type += ['patient_tn'] * 1  # true negative: no ground truth boxes, no detections.

            df_list_match_iou += [match_iou] * (len(df_list_preds) - len(df_list_match_iou))

        self.test_df = pd.DataFrame()
        self.test_df['pred_score'] = df_list_preds
        self.test_df['class_label'] = df_list_labels
        self.test_df['pred_class'] = df_list_class_preds
        self.test_df['pid'] = df_list_pids
        self.test_df['det_type'] = df_list_type
        self.test_df['fold'] = self.cf.fold
        self.test_df['match_iou'] = df_list_match_iou


    def evaluate_predictions(self, results_list, monitor_metrics=None):
        """
        Performs the matching of predicted boxes and ground truth boxes. Loops over the list of matching IoUs and
        foreground classes. The resulting info of each prediction is stored as one line in an internal dataframe,
        with the keys:
        det_type: 'det_tp' (true positive), 'det_fp' (false positive), 'det_fn' (false negative),
        'patient_tn' (true negative dummy for patients without ground truth boxes and without detections).
        pred_class: foreground class predicted for the object.
        pid: corresponding patient-id.
        pred_score: confidence score [0, 1].
        fold: corresponding fold of CV.
        match_iou: utilized IoU for matching.
        :param results_list: list of model predictions. Either from train/val_sampling (patch processing) for monitoring, with form:
        [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...]
        Or from val_patient/testing (patient processing), with form: [[results_0, pid_0], [results_1, pid_1], ...]
        :param monitor_metrics (optional): dict of dicts with all metrics of previous epochs.
        :return monitor_metrics: if provided (during training), return monitor_metrics now including results of the current epoch.
        """

        self.logger.info('evaluating in mode {}'.format(self.mode))

        batch_res_dicts = [batch[0] for batch in results_list]  # len: nr of batches in epoch
        if self.mode == 'train' or self.mode == 'val_sampling':
            # one pid per batch element
            # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...]
            # -> [pid_0, pid_1, ...]
            # additional list wrapping to conform with the per-patient batches below, where one pid is linked to more than one batch instance
            pid_list = [batch_instance_pid for batch in results_list for batch_instance_pid in batch[1]]
        elif self.mode == "val_patient" or self.mode == "test":
            # [[results_0, pid_0], [results_1, pid_1], ...] -> [pid_0, pid_1, ...]
            # in the patient batch iterator, there is only one pid per batch.
            pid_list = [np.unique(batch[1]) for batch in results_list]
            assert np.all([len(pid) == 1 for pid in
                           pid_list]), "pid list in patient-eval mode should only contain a single scalar per patient: {}".format(
                pid_list)
            pid_list = [pid[0] for pid in pid_list]
            # todo remove assert
            pid_list_orig = [item[1] for item in results_list]
            assert np.all(pid_list == pid_list_orig)
        else:
            raise Exception("undefined run mode encountered")

        self.eval_losses(batch_res_dicts)
        self.eval_boxes(batch_res_dicts, pid_list)

        if monitor_metrics is not None:
            # return all_stats, updated monitor_metrics
            return self.return_metrics(monitor_metrics)


    def return_metrics(self, monitor_metrics=None):
        """
        Calculates AP/AUC scores for the internal dataframe. Called directly from evaluate_predictions during training
        for monitoring, or from score_test_df during inference (for single folds or the aggregated test set). Loops
        over foreground classes and score_levels (typically 'rois' and 'patient'), gets the scores and stores them.
        Optionally creates plots of prediction histograms and roc/prc curves.
        :param monitor_metrics: dict of dicts with all metrics of previous epochs. This function adds the metrics of
        the current epoch and returns the same object.
        :return: all_stats: list. Contains dicts with the resulting scores for each combination of foreground class
        and score_level.
        :return: monitor_metrics
        """

        # -------------- monitoring independent of class, score level ------------
        if monitor_metrics is not None:
            for l_name in self.epoch_losses:
                monitor_metrics[l_name] = [self.epoch_losses[l_name]]


        df = self.test_df

        all_stats = []
        for cl in list(self.cf.class_dict.keys()):
            cl_df = df[df.pred_class == cl]

            for score_level in self.cf.report_score_level:
                stats_dict = {}
                stats_dict['name'] = 'fold_{} {} cl_{}'.format(self.cf.fold, score_level, cl)

                if score_level == 'rois':
                    # kick out dummy entries for true negative patients. not needed on roi-level.
                    spec_df = cl_df[cl_df.det_type != 'patient_tn']
                    stats_dict['ap'] = get_roi_ap_from_df([spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap])
                    # AUC is not sensible on roi-level, since true negative box predictions do not exist. It would
                    # reward higher amounts of low-confidence false positives.
                    stats_dict['auc'] = np.nan
                    stats_dict['roc'] = np.nan
                    stats_dict['prc'] = np.nan

                    # for the aggregated test set case, additionally get the scores for averaging over fold results.
                    if len(df.fold.unique()) > 1:
                        aps = []
                        for fold in df.fold.unique():
                            fold_df = spec_df[spec_df.fold == fold]
                            aps.append(get_roi_ap_from_df([fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap]))
                        stats_dict['mean_ap'] = np.mean(aps)
                        stats_dict['mean_auc'] = 0

                # on patient level, aggregate predictions per patient (pid): the patient predicted score is the highest
                # confidence prediction for this class. The patient class label is 1 if an roi of this class exists in
                # the patient, else 0.
                if score_level == 'patient':
                    # spec_df = cl_df.groupby(['pid'], as_index=False).agg({'class_label': 'max', 'pred_score': 'max', 'fold': 'first'})
                    spec_df = cl_df.groupby(["pid"], as_index=False).apply(eutils.patient_based_filter)

                    if len(spec_df.class_label.unique()) > 1:
                        stats_dict['auc'] = roc_auc_score(spec_df.class_label.tolist(), spec_df.pred_score.tolist())
                        stats_dict['roc'] = roc_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist())
                    else:
                        stats_dict['auc'] = np.nan
                        stats_dict['roc'] = np.nan

                    if (spec_df.class_label == 1).any():
                        stats_dict['ap'] = average_precision_score(spec_df.class_label.tolist(), spec_df.pred_score.tolist())
                        stats_dict['prc'] = precision_recall_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist())
                    else:
                        stats_dict['ap'] = np.nan
                        stats_dict['prc'] = np.nan

                    # for the aggregated test set case, additionally get the scores for averaging over fold results.
                    if len(df.fold.unique()) > 1:
                        aucs = []
                        aps = []
                        for fold in df.fold.unique():
                            fold_df = spec_df[spec_df.fold == fold]
                            if len(fold_df.class_label.unique()) > 1:
                                aucs.append(roc_auc_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist()))
                            if (fold_df.class_label == 1).any():
                                aps.append(average_precision_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist()))
                        stats_dict['mean_auc'] = np.mean(aucs)
                        stats_dict['mean_ap'] = np.mean(aps)

                # fill new results into the monitor_metrics dict. For simplicity, only one class (of interest) is monitored on patient level.
                if monitor_metrics is not None and not (score_level == 'patient' and cl != self.cf.patient_class_of_interest):
                    score_level_name = 'patient' if score_level == 'patient' else self.cf.class_dict[cl]
                    monitor_metrics[score_level_name + '_ap'].append(stats_dict['ap'] if stats_dict['ap'] > 0 else np.nan)
                    if score_level == 'patient':
                        monitor_metrics[score_level_name + '_auc'].append(
                            stats_dict['auc'] if stats_dict['auc'] > 0 else np.nan)

                if self.cf.plot_prediction_histograms:
                    out_filename = os.path.join(self.hist_dir, 'pred_hist_{}_{}_{}_cl{}'.format(
                        self.cf.fold, 'val' if 'val' in self.mode else self.mode, score_level, cl))
                    # type_list = None if score_level == 'patient' else spec_df.det_type.tolist()
                    type_list = spec_df.det_type.tolist()
                    utils.split_off_process(plotting.plot_prediction_hist, spec_df.class_label.tolist(),
                                            spec_df.pred_score.tolist(), type_list, out_filename)

                all_stats.append(stats_dict)

                # analysis of the hyper-parameter cf.min_det_thresh, for optimization on the validation set.
                if self.cf.scan_det_thresh:
                    conf_threshs = list(np.arange(0.9, 1, 0.01))
                    pool = Pool(processes=8)
                    mp_inputs = [[spec_df, ii, self.cf.per_patient_ap] for ii in conf_threshs]
                    aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1)
                    pool.close()
                    pool.join()
                    self.logger.info('results from scanning over det_threshs: {}'.format([[i, j] for i, j in zip(conf_threshs, aps)]))

        if self.cf.plot_stat_curves:
            out_filename = os.path.join(self.curves_dir, '{}_{}_stat_curves'.format(self.cf.fold, self.mode))
            utils.split_off_process(plotting.plot_stat_curves, all_stats, out_filename)

        # get average stats over foreground classes on roi level.
        avg_ap = np.mean([d['ap'] for d in all_stats if 'rois' in d['name']])
        all_stats.append({'name': 'average_foreground_roi', 'auc': 0, 'ap': avg_ap})
        if len(df.fold.unique()) > 1:
            avg_mean_ap = np.mean([d['mean_ap'] for d in all_stats if 'rois' in d['name']])
            all_stats[-1]['mean_ap'] = avg_mean_ap
            all_stats[-1]['mean_auc'] = 0

        # in small data sets, values of the model_selection_criterion can be identical across epochs, which breaks the
        # ranking of the model_selector. Thus, perturb identical values by a negligible random term.
        for sc in self.cf.model_selection_criteria:
            if 'val' in self.mode and monitor_metrics[sc].count(monitor_metrics[sc][-1]) > 1 and monitor_metrics[sc][-1] is not None:
                monitor_metrics[sc][-1] += 1e-6 * np.random.rand()

        return all_stats, monitor_metrics


    def write_to_results_table(self, stats, metrics_to_score, out_path):
        """Write overall results to a common inter-experiment table.
        :param stats: list of stats dicts as returned by return_metrics.
        :param metrics_to_score: list of metric names (keys of the stats dicts) to write out.
        :param out_path: path to the csv results table.
        """

        with open(out_path, 'a') as handle:
            # ---column headers---
            handle.write('\n{},'.format("Experiment Name"))
            handle.write('{},'.format("Time Stamp"))
            handle.write('{},'.format("Samples Seen"))
            handle.write('{},'.format("Spatial Dim"))
            handle.write('{},'.format("Patch Size"))
            handle.write('{},'.format("CV Folds"))
            handle.write('{},'.format("WBC IoU"))
            handle.write('{},'.format("Merge-2D-to-3D IoU"))
            for s in stats:
                #if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "average" in s["name"]:
                for metric in metrics_to_score:
                    if metric in s.keys() and not np.isnan(s[metric]):
                        if metric == 'ap':
                            handle.write('{} : {}_{},'.format(s['name'], metric.upper(),
                                                              "_".join((np.array(self.cf.ap_match_ious) * 100)
                                                                       .astype("int").astype("str"))))
                        else:
                            handle.write('{} : {},'.format(s['name'], metric.upper()))
                    else:
                        print("WARNING: skipped metric {} since not avail".format(metric))
            handle.write('\n')

            # --- columns content---
            handle.write('{},'.format(self.cf.exp_dir.split(os.sep)[-1]))
            handle.write('{},'.format(time.strftime("%d%b%y %H:%M:%S")))
            handle.write('{},'.format(self.cf.num_epochs * self.cf.num_train_batches * self.cf.batch_size))
            handle.write('{}D,'.format(self.cf.dim))
            handle.write('{},'.format("x".join([str(self.cf.patch_size[i]) for i in range(self.cf.dim)])))
            handle.write('{},'.format(str(self.test_df.fold.unique().tolist()).replace(",", "")))
            handle.write('{},'.format(self.cf.wcs_iou))
            handle.write('{},'.format(self.cf.merge_3D_iou if self.cf.merge_2D_to_3D_preds else str("N/A")))
            for s in stats:
                #if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "mean" in s["name"]:
                for metric in metrics_to_score:
                    if metric in s.keys() and not np.isnan(s[metric]):
                        handle.write('{:0.3f}, '.format(s[metric]))
            handle.write('\n')

    def score_test_df(self, internal_df=True):
        """
        Writes out the resulting scores to text files: first pickles the internal data frame of the current fold and
        writes its scores to a text file. Then checks whether data-frame pickles of all folds of the cross-validation
        exist in the test directory. If so, loads all dataframes, aggregates the test sets over folds, and calculates
        and writes out the overall metrics.
        """
429
        if internal_df:
430
431
            self.test_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_df.pickle'.format(self.cf.fold)))
432
            stats, _ = self.return_metrics()
433
434
            with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle:
435
                handle.write('\n****************************\n')
436
                handle.write('\nresults for fold {} \n'.format(self.cf.fold))
437
                handle.write('\n****************************\n')
438
                handle.write('\nfold df shape {}\n  \n'.format(self.test_df.shape))
439
                for s in stats:
440
                    handle.write('AUC {:0.4f}  AP {:0.4f} {} \n'.format(s['auc'], s['ap'], s['name']))
441
442
            fold_df_paths = [ii for ii in os.listdir(self.cf.test_dir) if ('test_df.pickle' in ii and not 'overall' in ii)]
443
            if len(fold_df_paths) == self.cf.n_cv_splits:
444
                results_table_path = os.path.join((os.sep).join(self.cf.exp_dir.split(os.sep)[:-1]), 'results_table.csv')
445
446
                if not self.cf.hold_out_test_set or not self.cf.ensemble_folds:
447
                    with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle:
448
                        self.cf.fold = 'overall'
449
                        dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_df_paths]
450
                        for ix, df in enumerate(dfs_list):
451
                            df['fold'] = ix
452
                        self.test_df = pd.concat(dfs_list)
453
                        stats, _ = self.return_metrics()
454
                        handle.write('\n****************************\n')
455
                        handle.write('\nOVERALL RESULTS \n')
456
                        handle.write('\n****************************\n')
457
                        handle.write('\ndf shape \n  \n'.format(self.test_df.shape))
458
                        for s in stats:
459
                            handle.write('\nAUC {:0.4f} (mu {:0.4f})  AP {:0.4f} (mu {:0.4f})  {}\n '
460
                                         .format(s['auc'], s['mean_auc'], s['ap'], s['mean_ap'], s['name']))
461
                    metrics_to_score = ["auc", "mean_auc", "ap", "mean_ap"]
462
                    self.write_to_results_table(stats, metrics_to_score, out_path=results_table_path)
463
                else:
464
                    metrics_to_score = ["auc", "ap"]
465
                    self.write_to_results_table(stats, metrics_to_score, out_path=results_table_path)
466
467
468
def get_roi_ap_from_df(inputs):
    '''
    :param inputs: tuple or list (df, det_thresh, per_patient_ap), packed to allow use with multiprocessing.Pool.map.
        df: data frame.
        det_thresh: min threshold for filtering out low confidence predictions.
        per_patient_ap: boolean flag. evaluate average precision per patient id and average over patient ids,
        instead of computing one ap over the whole data set.
    :return: average_precision (float)
    '''
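    # Example call mirroring the usage in return_metrics (names hypothetical here):
    #   ap = get_roi_ap_from_df([spec_df, cf.min_det_thresh, cf.per_patient_ap])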
    df, det_thresh, per_patient_ap = inputs

    if per_patient_ap:
        pids_list = df.pid.unique()
        aps = []
        for match_iou in df.match_iou.unique():
            iou_df = df[df.match_iou == match_iou]
            for pid in pids_list:
                pid_df = iou_df[iou_df.pid == pid]
                all_p = len(pid_df[pid_df.class_label == 1])
                pid_df = pid_df[(pid_df.det_type == 'det_fp') | (pid_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False)
                pid_df = pid_df[pid_df.pred_score > det_thresh]
                if len(pid_df) == 0 and all_p == 0:
                    pass
                elif len(pid_df) > 0 and all_p == 0:
                    aps.append(0)
                else:
                    aps.append(compute_roi_ap(pid_df, all_p))
        return np.mean(aps)

    else:
        aps = []
        for match_iou in df.match_iou.unique():
            iou_df = df[df.match_iou == match_iou]
            all_p = len(iou_df[iou_df.class_label == 1])
            iou_df = iou_df[(iou_df.det_type == 'det_fp') | (iou_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False)
            iou_df = iou_df[iou_df.pred_score > det_thresh]
            if all_p > 0:
                aps.append(compute_roi_ap(iou_df, all_p))
        return np.mean(aps)


def compute_roi_ap(df, all_p):
    """
    Adapted from: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
    :param df: dataframe containing class labels of predictions sorted in descending order by their prediction score.
    :param all_p: number of all ground truth objects. (for the denominator of recall.)
    :return: average precision: precision averaged over 101 evenly spaced recall thresholds.
    """
    tp = df.class_label.values
    fp = (tp == 0) * 1
    # recall thresholds, where precision will be measured
    R = np.linspace(.0, 1, 101, endpoint=True)
    tp_sum = np.cumsum(tp)
    fp_sum = np.cumsum(fp)
    nd = len(tp)
    rc = tp_sum / all_p
    pr = tp_sum / (fp_sum + tp_sum)
    # initialize precision array over recall steps.
    q = np.zeros((len(R),))

    # numpy is slow without cython optimization for accessing elements:
    # using python lists gives a significant speed improvement.
    pr = pr.tolist()
    q = q.tolist()
    for i in range(nd - 1, 0, -1):
        if pr[i] > pr[i - 1]:
            pr[i - 1] = pr[i]

    # discretize empirical recall steps with the given bins.
    inds = np.searchsorted(rc, R, side='left')
    try:
        for ri, pi in enumerate(inds):
            q[ri] = pr[pi]
    except IndexError:
        # recall thresholds beyond the maximum achieved recall keep their initial precision of 0.
        pass

    return np.mean(q)