#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import code
import numpy as np
import torch
from scipy.stats import norm
from collections import OrderedDict
from multiprocessing import Pool
import pickle
from copy import deepcopy
import pandas as pd

import utils.exp_utils as utils
from plotting import plot_batch_prediction


class Predictor:
    """
    Prediction pipeline:
    - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader.
    - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward)
    - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward)

    Ensembling (mode == 'test'):
    - for inference, forwards 4 mirrored versions of image through the model and unmolds predictions afterwards
      accordingly (method: data_aug_forward)
    - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each
      parameter-set loops over entire test set, runs prediction pipeline for each patient.
      (method: predict_test_set)

    Consolidation of predictions:
    - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling,
      performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outputs.
    - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum suppression).
      (external function: merge_2D_to_3D_preds_per_patient)

    Ground truth handling:
    - dismisses any ground truth boxes returned by the model (happens in validation mode, patch-based ground truth)
    - if provided by data loader, adds 3D ground truth to the final predictions to be passed to the evaluator.
    """
    def __init__(self, cf, net, logger, mode):
        """
        :param cf: config object.
        :param net: model instance. In validation mode, contains parameters of current epoch.
        :param logger: experiment logger.
        :param mode: 'val' for patient-based validation/monitoring, 'test' for inference.
        """
        self.cf = cf
        self.logger = logger

        # mode is 'val' for patient-based validation/monitoring and 'test' for inference.
        self.mode = mode

        # model instance. In validation mode, contains parameters of current epoch.
        self.net = net

        # rank of current epoch loaded (for temporal averaging). this info is added to each prediction,
        # for correct weighting during consolidation.
        self.rank_ix = '0'

        # number of ensembled models. used to calculate the number of expected predictions per position
        # during consolidation of predictions. Default is 1 (no ensembling, e.g. in validation).
        self.n_ens = 1

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if self.mode == 'test':
            try:
                self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs]
            except OSError as e:
                # np.load raises FileNotFoundError/OSError if the ranking file is missing.
                raise RuntimeError('no epoch ranking file in fold directory. '
                                   'seems like you are trying to run testing without prior training...') from e
            self.n_ens = cf.test_n_epochs
            if self.cf.test_aug:
                self.n_ens *= 4

            self.example_plot_dir = os.path.join(cf.test_dir, "example_plots")
            os.makedirs(self.example_plot_dir, exist_ok=True)


    def predict_patient(self, batch):
        """
        predicts one patient.
        called either directly via loop over validation set in exec.py (mode=='val')
        or from self.predict_test_set (mode=='test').
        in val mode: adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions.
        in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are
        done in self.predict_test_set, because patient predictions across several epochs might be needed
        to be collected first, in case of temporal ensembling).
        :return. results_dict: stores the results for one patient. dictionary with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                 - losses (only in validation mode)
        """
        #self.logger.info('\revaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold))
        print('\revaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold), end="", flush=True)

        # True if patient is provided in patches and predictions need to be tiled.
        self.patched_patient = 'patch_crop_coords' in batch.keys()

        # forward batch through prediction pipeline.
        results_dict = self.data_aug_forward(batch)

        if self.mode == 'val':
            # add 3D ground truth boxes so the evaluator can match predictions against them.
            for b in range(batch['patient_bb_target'].shape[0]):
                for t in range(len(batch['patient_bb_target'][b])):
                    results_dict['boxes'][b].append({'box_coords': batch['patient_bb_target'][b][t],
                                                     'box_label': batch['patient_roi_labels'][b][t],
                                                     'box_type': 'gt'})

            if self.patched_patient:
                wcs_input = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.wcs_iou, self.n_ens]
                results_dict['boxes'] = apply_wbc_to_patient(wcs_input)[0]

            if self.cf.merge_2D_to_3D_preds:
                merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou]
                results_dict['boxes'] = merge_2D_to_3D_preds_per_patient(merge_dims_inputs)[0]

        return results_dict


    def predict_test_set(self, batch_gen, return_results=True):
        """
        wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through
        the test set and collects predictions per patient. Also flattens the results per patient and epoch
        and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and
        optionally consolidates and returns predictions immediately.
        :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation.
        """
        dict_of_patient_results = OrderedDict()

        # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling).
        weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch), 'params.pth') for epoch in
                        self.epoch_ranking]

        self.logger.info('ensembling over checkpoint paths: {}'.format(weight_paths))

        n_test_plots = min(batch_gen['n_test'], 1)

        for rank_ix, weight_path in enumerate(weight_paths):

            self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path)))
            #code.interact(local=locals())
            # map_location keeps loading functional when checkpoints were saved on a different device (e.g. GPU -> CPU).
            self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
            self.net.eval()
            self.rank_ix = str(rank_ix)  # get string of current rank for unique patch ids.
            plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=n_test_plots, replace=False)

            with torch.no_grad():
                for i in range(batch_gen['n_test']):

                    batch = next(batch_gen['test'])

                    # store batch info in patient entry of results dict (only once, on the first rank).
                    if rank_ix == 0:
                        dict_of_patient_results[batch['pid']] = {}
                        dict_of_patient_results[batch['pid']]['results_dicts'] = []
                        dict_of_patient_results[batch['pid']]['patient_bb_target'] = batch['patient_bb_target']
                        dict_of_patient_results[batch['pid']]['patient_roi_labels'] = batch['patient_roi_labels']

                    # call prediction pipeline and store results in dict.
                    results_dict = self.predict_patient(batch)
                    dict_of_patient_results[batch['pid']]['results_dicts'].append({"boxes": results_dict['boxes']})

                    if i in plot_batches and not self.patched_patient:
                        # view qualitative results of random test case
                        # plotting for patched patients is too expensive, thus not done. Change at will.
                        try:
                            out_file = os.path.join(self.example_plot_dir,
                                                    'batch_example_test_{}_rank_{}.png'.format(self.cf.fold,
                                                                                               rank_ix))
                            results_for_plotting = deepcopy(results_dict)
                            # seg preds of test augs are included separately. for viewing, only show aug 0 (merging
                            # would need multiple changes, incl in every model).
                            if results_for_plotting["seg_preds"].shape[1] > 1:
                                results_for_plotting["seg_preds"] = results_dict['seg_preds'][:, [0]]
                            for bix in range(batch["seg"].shape[0]):  # batch dim should be 1
                                for tix in range(len(batch['bb_target'][bix])):
                                    results_for_plotting['boxes'][bix].append({'box_coords': batch['bb_target'][bix][tix],
                                                                               'box_label': batch['class_target'][bix][tix],
                                                                               'box_type': 'gt'})
                            utils.split_off_process(plot_batch_prediction, batch, results_for_plotting, self.cf,
                                                    outfile=out_file, suptitle="Test plot:\nunmerged TTA overlayed.")
                        except Exception as e:
                            # plotting is best-effort only; never abort the test run because of it.
                            self.logger.info("WARNING: error in plotting example test batch: {}".format(e))


        self.logger.info('finished predicting test set. starting post-processing of predictions.')
        results_per_patient = []

        # loop over patients again to flatten results across epoch predictions.
        # if provided, add ground truth boxes for evaluation.
        for pid, p_dict in dict_of_patient_results.items():

            tmp_ens_list = p_dict['results_dicts']
            results_dict = {}
            # collect all boxes/seg_preds of same batch_instance over temporal instances.
            b_size = len(tmp_ens_list[0]["boxes"])
            results_dict['boxes'] = [[item for rank_dict in tmp_ens_list for item in rank_dict["boxes"][batch_instance]]
                                     for batch_instance in range(b_size)]

            # TODO return for instance segmentation:
            # results_dict['seg_preds'] = np.mean(results_dict['seg_preds'], 1)[:, None]
            # results_dict['seg_preds'] = np.array([[item for d in tmp_ens_list for item in d['seg_preds'][batch_instance]]
            #                                       for batch_instance in range(len(tmp_ens_list[0]['boxes']))])

            # add 3D ground truth boxes for evaluation.
            for b in range(p_dict['patient_bb_target'].shape[0]):
                for t in range(len(p_dict['patient_bb_target'][b])):
                    results_dict['boxes'][b].append({'box_coords': p_dict['patient_bb_target'][b][t],
                                                     'box_label': p_dict['patient_roi_labels'][b][t],
                                                     'box_type': 'gt'})
            results_per_patient.append([results_dict, pid])

        # save out raw predictions.
        out_string = 'raw_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'raw_pred_boxes_list'
        with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
            pickle.dump(results_per_patient, handle)

        if return_results:
            final_patient_box_results = [(res_dict["boxes"], pid) for res_dict, pid in results_per_patient]
            # consolidate predictions.
            self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
                self.cf.wcs_iou, self.n_ens))
            pool = Pool(processes=8)
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, self.n_ens] for ii in final_patient_box_results]
            final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
            pool.close()
            pool.join()

            # merge 2D boxes to 3D cubes. (if model predicts 2D but evaluation is run in 3D)
            if self.cf.merge_2D_to_3D_preds:
                self.logger.info('applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
                pool = Pool(processes=6)
                mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results]
                final_patient_box_results = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)
                pool.close()
                pool.join()

            # final_patient_box_results holds [avg_boxes, pid] if wbc
            for ix in range(len(results_per_patient)):
                assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid"
                results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0]

            out_string = 'wbc_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'wbc_pred_boxes_list'
            with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
                pickle.dump(results_per_patient, handle)


            return results_per_patient


    def load_saved_predictions(self, apply_wbc=False):
        """
        loads raw predictions saved by self.predict_test_set. consolidates and merges 2D boxes to 3D cubes for evaluation.
        (if model predicts 2D but evaluation is run in 3D)
        :return: (optionally) results_list: list over patient results. each entry is a dict with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation.
        """

        # load predictions for a single test-set fold.
        results_file = 'raw_pred_boxes_hold_out_list.pickle' if self.cf.hold_out_test_set else 'raw_pred_boxes_list.pickle'
        if not self.cf.hold_out_test_set or not self.cf.ensemble_folds:
            with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle:
                results_list = pickle.load(handle)
            box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list]
            da_factor = 4 if self.cf.test_aug else 1
            n_ens = self.cf.test_n_epochs * da_factor
            self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format(
                len(results_list), n_ens))

        # if hold out test set was predicted, aggregate predictions of all trained models
        # corresponding to all CV-folds and flatten them.
        else:
            self.logger.info("loading saved predictions of hold-out test set and ensembling over folds.")
            fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if
                                os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")])

            results_list = []
            folds_loaded = 0
            for fold in range(self.cf.n_cv_splits):
                fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold))
                if fold_dir in fold_dirs:
                    with open(os.path.join(fold_dir, results_file), 'rb') as handle:
                        fold_list = pickle.load(handle)
                    results_list += fold_list
                    folds_loaded += 1
                else:
                    self.logger.info("Skipping fold {} since no saved predictions found.".format(fold))
            box_results_list = []
            for res_dict, pid in results_list:  # without filtering gt out:
                box_results_list.append((res_dict['boxes'], pid))

            da_factor = 4 if self.cf.test_aug else 1
            n_ens = self.cf.test_n_epochs * da_factor * folds_loaded

        # consolidate predictions.
        if apply_wbc:
            self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
                self.cf.wcs_iou, n_ens))
            pool = Pool(processes=6)
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, n_ens] for ii in box_results_list]
            box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
            pool.close()
            pool.join()

        # merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D)
        if self.cf.merge_2D_to_3D_preds:
            self.logger.info(
                'applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
            pool = Pool(processes=6)
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list]
            box_results_list = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)
            pool.close()
            pool.join()


        for ix in range(len(results_list)):
            assert np.all(
                results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results"
            results_list[ix][0]["boxes"] = box_results_list[ix][0]

        return results_list  # holds (results_dict, pid)


    def data_aug_forward(self, batch):
        """
        in val_mode: passes batch through to spatial_tiling method without data_aug.
        in test_mode: if cf.test_aug is set in configs, creates 4 mirrored versions of the input image,
        passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions
        to original image version.
        :return. results_dict: stores the results for one patient. dictionary with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                            and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                 - losses (only in validation mode)
        """
        patch_crops = batch['patch_crop_coords'] if self.patched_patient else None
        results_list = [self.spatial_tiling_forward(batch, patch_crops)]
        org_img_shape = batch['original_img_shape']

        if self.mode == 'test' and self.cf.test_aug:

            if self.patched_patient:
                # apply mirror transformations to patch-crop coordinates, for correct tiling in spatial_tiling method.
                mirrored_patch_crops = get_mirrored_patch_crops(patch_crops, batch['original_img_shape'])
            else:
                mirrored_patch_crops = [None] * 3

            img = np.copy(batch['data'])

            # first mirroring: y-axis.
            batch['data'] = np.flip(img, axis=2).copy()
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[0], n_aug='1')
            # re-transform coordinates.
            for ix in range(len(chunk_dict['boxes'])):
                for boxix in range(len(chunk_dict['boxes'][ix])):
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
                    coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2]
                    coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0]
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
            # re-transform segmentation predictions.
            chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=2)
            results_list.append(chunk_dict)

            # second mirroring: x-axis.
            batch['data'] = np.flip(img, axis=3).copy()
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[1], n_aug='2')
            # re-transform coordinates.
            for ix in range(len(chunk_dict['boxes'])):
                for boxix in range(len(chunk_dict['boxes'][ix])):
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
                    coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3]
                    coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1]
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
            # re-transform segmentation predictions.
            chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=3)
            results_list.append(chunk_dict)

            # third mirroring: y-axis and x-axis.
            batch['data'] = np.flip(np.flip(img, axis=2), axis=3).copy()
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[2], n_aug='3')
            # re-transform coordinates.
            for ix in range(len(chunk_dict['boxes'])):
                for boxix in range(len(chunk_dict['boxes'][ix])):
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
                    coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2]
                    coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0]
                    coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3]
                    coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1]
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
            # re-transform segmentation predictions.
            chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=2), axis=3).copy()
            results_list.append(chunk_dict)

            batch['data'] = img

        # aggregate all boxes/seg_preds per batch element from data_aug predictions.
        results_dict = {}
        results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]]
                                 for batch_instance in range(org_img_shape[0])]
        results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]]
                                              for batch_instance in range(org_img_shape[0])])
        if self.mode == 'val':
            try:
                results_dict['torch_loss'] = results_list[0]['torch_loss']
                results_dict['class_loss'] = results_list[0]['class_loss']
            except KeyError:
                # losses are not necessarily monitored.
                pass
        return results_dict


    def spatial_tiling_forward(self, batch, patch_crops=None, n_aug='0'):
        """
        forwards batch to batch_tiling_forward method and receives and returns a dictionary with results.
        if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis.
        this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates.
        Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as
        'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances
        into account). all box predictions get additional information about the amount of overlapping patches at the
        respective position (used for consolidation).
        :return. results_dict: stores the results for one patient. dictionary with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                            and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                 - losses (only in validation mode)
        """
        if patch_crops is not None:

            patches_dict = self.batch_tiling_forward(batch)

            results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]}

            # instantiate segmentation output array. Will contain averages over patch predictions.
            out_seg_preds = np.zeros(batch['original_img_shape'], dtype=np.float16)[:, 0][:, None]
            # counts patch instances per pixel-position.
            patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8')

            # unmold segmentation outputs. loop over patches.
            for pix, pc in enumerate(patch_crops):
                if self.cf.dim == 3:
                    out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix][None]
                    patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1
                else:
                    out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix]
                    patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1

            # take mean in overlapping areas.
            out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0]
            results_dict['seg_preds'] = out_seg_preds

            # unmold box outputs. loop over patches.
            for pix, pc in enumerate(patch_crops):
                patch_boxes = patches_dict['boxes'][pix]

                for box in patch_boxes:

                    # add unique patch id for consolidation of predictions.
                    box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix)

                    # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers.
                    # hence they will be downweighted for consolidation, using the 'box_patch_center_factor', which is
                    # obtained by a normal distribution over positions in the patch and average over spatial dimensions.
                    # Also the info 'box_n_overlaps' is stored for consolidation, which depicts the amount of
                    # overlapping patches at the box's position.
                    c = box['box_coords']
                    box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)]
                    if self.cf.dim == 3:
                        box_centers.append((c[4] + c[5]) / 2)
                    # 'ph' is the half patch size per spatial dim (renamed from 'pc' to avoid shadowing the crop coords).
                    box['box_patch_center_factor'] = np.mean(
                        [norm.pdf(bc, loc=ph, scale=ph * 0.8) * np.sqrt(2 * np.pi) * ph * 0.8 for bc, ph in
                         zip(box_centers, np.array(self.cf.patch_size) / 2)])
                    if self.cf.dim == 3:
                        c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]])
                        int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
                        box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]])
                        results_dict['boxes'][0].append(box)
                    else:
                        c += np.array([pc[0], pc[2], pc[0], pc[2]])
                        int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
                        box['box_n_overlaps'] = np.mean(patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]])
                        results_dict['boxes'][pc[4]].append(box)

            if self.mode == 'val':
                try:
                    results_dict['torch_loss'] = patches_dict['torch_loss']
                    results_dict['class_loss'] = patches_dict['class_loss']
                except KeyError:
                    pass
        # if predictions are not patch-based:
        # add patch-origin info to boxes (entire image is the same patch with overlap=1) and return results.
        else:
            results_dict = self.batch_tiling_forward(batch)
            for b in results_dict['boxes']:
                for box in b:
                    box['box_patch_center_factor'] = 1
                    box['box_n_overlaps'] = 1
                    box['patch_id'] = self.rank_ix + '_' + n_aug

        return results_dict


    def batch_tiling_forward(self, batch):
        """
        calls the actual network forward method. in patch-based prediction, the batch dimension might be overloaded
        with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of
        batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded).
        test mode calls the test forward method, no ground truth required / involved.
        :return. results_dict: stores the results for one patient. dictionary with keys:
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                            and a dummy batch dimension of 1 for 3D predictions.
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                 - losses (only in validation mode)
        """
        #self.logger.info('forwarding (patched) patient with shape: {}'.format(batch['data'].shape))

        img = batch['data']

        #batch['data'] = torch.from_numpy(batch['data']).float().to(self.device)

        if img.shape[0] <= self.cf.batch_size:

            if self.mode == 'val':
                # call training method to monitor losses
                results_dict = self.net.train_forward(batch, is_validation=True)
                # discard returned ground-truth boxes (also training info boxes).
                results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]
            else:
                results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test)

        else:
            split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.cf.batch_size])
            chunk_dicts = []
            for chunk_ixs in split_ixs[1:]:  # first split is elements before 0, so empty
                b = {k: batch[k][chunk_ixs] for k in batch.keys()
                     if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])}
                if self.mode == 'val':
                    chunk_dicts += [self.net.train_forward(b, is_validation=True)]
                else:
                    chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)]


            results_dict = {}
            # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...])
            results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']]
            results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']])

            if self.mode == 'val':
                try:
                    # estimate metrics by mean over batch_chunks. Most similar to training metrics.
                    results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts]))
                    results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts])
                except KeyError:
                    # losses are not necessarily monitored
                    pass
                # discard returned ground-truth boxes (also training info boxes).
                results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]

        return results_dict



def apply_wbc_to_patient(inputs):
    """
    wrapper around prediction box consolidation: weighted cluster scoring (wcs). processes a single patient.
    loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes,
    aggregates and stores results in new list.
    :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is
                                   one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D
                                   predictions, and a dummy batch dimension of 1 for 3D predictions.
    :return. pid: string. patient id.
    """
    in_patient_results_list, pid, class_dict, wcs_iou, n_ens = inputs
    out_patient_results_list = [[] for _ in range(len(in_patient_results_list))]

    for bix, b in enumerate(in_patient_results_list):

        for cl in list(class_dict.keys()):

            boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
            box_coords = np.array([b[1]['box_coords'] for b in boxes])
            box_scores = np.array([b[1]['box_score'] for b in boxes])
            box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes])
            box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes])
            box_patch_id = np.array([b[1]['patch_id'] for b in boxes])

            if 0 not in box_scores.shape:
                keep_scores, keep_coords = weighted_box_clustering(
                    np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None],
                                    box_n_overlaps[:, None]), axis=1), box_patch_id, wcs_iou, n_ens)

                for boxix in range(len(keep_scores)):
                    out_patient_results_list[bix].append({'box_type': 'det', 'box_coords': keep_coords[boxix],
                                                          'box_score': keep_scores[boxix], 'box_pred_class_id': cl})

        # add gt boxes back to new output list.
        out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt'])

    return [out_patient_results_list, pid]



def merge_2D_to_3D_preds_per_patient(inputs):
    """
    wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension)
    and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Suppression
    (Detailed methodology is described in nms_2to3D).
    :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is
                                 one dictionary: [[box_0, ...], [box_n,...]].
    :return. pid: string. patient id.
    """
    in_patient_results_list, pid, class_dict, merge_3D_iou = inputs
    out_patient_results_list = []

    for cl in list(class_dict.keys()):
        boxes, slice_ids = [], []
        # collect box predictions over batch dimension (slices) and store slice info as slice_ids.
        for bix, b in enumerate(in_patient_results_list):
            det_boxes = [(ix, box) for ix, box in enumerate(b) if
                         (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
            boxes += det_boxes
            slice_ids += [bix] * len(det_boxes)

        box_coords = np.array([b[1]['box_coords'] for b in boxes])
        box_scores = np.array([b[1]['box_score'] for b in boxes])
        slice_ids = np.array(slice_ids)

        if 0 not in box_scores.shape:
            keep_ix, keep_z = nms_2to3D(
                np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou)
        else:
            keep_ix, keep_z = [], []

        # store kept predictions in new results list and add corresponding z-dimension info to coordinates.
        for kix, kz in zip(keep_ix, keep_z):
            out_patient_results_list.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz,
                                             'box_score': box_scores[kix], 'box_pred_class_id': cl})

    gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt']
    if len(gt_boxes) > 0:
        assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D."
    out_patient_results_list += gt_boxes

    # add dummy batch dimension 1 for 3D.
    return [[out_patient_results_list], pid]



def weighted_box_clustering(dets, box_patch_id, thresh, n_ens):
    """
    consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling.
    clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the
    average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered
    its position in the patch is) and the size of the corresponding box.
    The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position
    (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique
    patches in the cluster, which did not contribute any predicted boxes.
    :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs)
    :param thresh: threshold for iou_matching.
    :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position)
    :return: keep_scores: (n_keep) new scores of boxes to be kept.
    :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept.
    """
    dim = 2 if dets.shape[1] == 7 else 3
    y1 = dets[:, 0]
    x1 = dets[:, 1]
    y2 = dets[:, 2]
    x2 = dets[:, 3]
    scores = dets[:, -3]
    box_pc_facts = dets[:, -2]
    box_n_ovs = dets[:, -1]

    areas = (y2 - y1 + 1) * (x2 - x1 + 1)

    if dim == 3:
        z1 = dets[:, 4]
        z2 = dets[:, 5]
        areas *= (z2 - z1 + 1)

    # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24)
    order = scores.argsort()[::-1]

    keep = []
    keep_scores = []
    keep_coords = []

    while order.size > 0:
        i = order[0]  # highest scoring element
        xx1 = np.maximum(x1[i], x1[order])
        yy1 = np.maximum(y1[i], y1[order])
        xx2 = np.minimum(x2[i], x2[order])
        yy2 = np.minimum(y2[i], y2[order])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        if dim == 3:
            zz1 = np.maximum(z1[i], z1[order])
            zz2 = np.minimum(z2[i], z2[order])
            d = np.maximum(0.0, zz2 - zz1 + 1)
            inter *= d

        # overlap between currently highest scoring box and all boxes.
        ovr = inter / (areas[i] + areas[order] - inter)

        # get all the predictions that match the current box to build one cluster.
        matches = np.argwhere(ovr > thresh)

        match_n_ovs = box_n_ovs[order[matches]]
        match_pc_facts = box_pc_facts[order[matches]]
        match_patch_id = box_patch_id[order[matches]]
        match_ov_facts = ovr[matches]
        match_areas = areas[order[matches]]
        match_scores = scores[order[matches]]

        # weight all scores in cluster by patch factors, and size.
        match_score_weights = match_ov_facts * match_areas * match_pc_facts
        match_scores *= match_score_weights

        # for the weighted average, scores have to be divided by the number of total expected preds at the position
        # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is
        # multiplied by the mean overlaps of patches at this position (boxes of the cluster might partly be
        # in areas of different overlaps).
        n_expected_preds = n_ens * np.mean(match_n_ovs)

        # the number of missing predictions is obtained as the number of patches,
        # which did not contribute any prediction to the current cluster.
        n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0]))

        # missing preds are given the mean weighting
        # (expected prediction is the mean over all predictions in cluster).
        denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights)

        # compute weighted average score for the cluster
        avg_score = np.sum(match_scores) / denom

        # compute weighted average of coordinates for the cluster. now only take existing
        # predictions into account.
+ avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)] + if dim == 3: + avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores)) + avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores)) + + # some clusters might have very low scores due to high amounts of missing predictions. + # filter out the with a conservative threshold, to speed up evaluation. + if avg_score > 0.01: + keep_scores.append(avg_score) + keep_coords.append(avg_coords) + + # get index of all elements that were not matched and discard all others. + inds = np.where(ovr <= thresh)[0] + order = order[inds] + + return keep_scores, keep_coords + + + +def nms_2to3D(dets, thresh): + """ + Merges 2D boxes to 3D cubes. Therefore, boxes of all slices are projected into one slices. An adaptation of Non-maximum surpression + is applied, where clusters are found (like in NMS) with an extra constrained, that surpressed boxes have to have 'connected' + z-coordinates w.r.t the core slice (cluster center, highest scoring box). 'connected' z-coordinates are determined + as the z-coordinates with predictions until the first coordinate, where no prediction was found. + + example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest + scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57. + Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was + found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates + are surpressed. All others are kept for building of further clusters. 
+ + This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery) + predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster. + + :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id) + :param thresh: iou matchin threshold (like in NMS). + :return: keep: (n_keep) 1D tensor of indices to be kept. + :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes. + """ + y1 = dets[:, 0] + x1 = dets[:, 1] + y2 = dets[:, 2] + x2 = dets[:, 3] + scores = dets[:, -2] + slice_id = dets[:, -1] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + keep_z = [] + + while order.size > 0: # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) + i = order[0] # pop higehst scoring element + xx1 = np.maximum(x1[i], x1[order]) + yy1 = np.maximum(y1[i], y1[order]) + xx2 = np.minimum(x2[i], x2[order]) + yy2 = np.minimum(y2[i], y2[order]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + + ovr = inter / (areas[i] + areas[order] - inter) + matches = np.argwhere(ovr > thresh) # get all the elements that match the current box and have a lower score + + slice_ids = slice_id[order[matches]] + core_slice = slice_id[int(i)] + upper_wholes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids] + lower_wholes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids] + max_valid_slice_id = np.min(upper_wholes) if len(upper_wholes) > 0 else np.max(slice_ids) + min_valid_slice_id = np.max(lower_wholes) if len(lower_wholes) > 0 else np.min(slice_ids) + z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)] + + z1 = np.min(slice_id[order[z_matches]]) - 1 + z2 = np.max(slice_id[order[z_matches]]) + 1 + + keep.append(i) + keep_z.append([z1, z2]) 
+ order = np.delete(order, z_matches, axis=0) + + return keep, keep_z + + + +def get_mirrored_patch_crops(patch_crops, org_img_shape): + """ + apply 3 mirrror transformations (x-axis, y-axis, x&y-axis) + to given patch crop coordinates and return the transformed coordinates. + Handles 2D and 3D coordinates. + :param patch_crops: list of crops: each element is a list of coordinates for given crop [[y1, x1, ...], [y1, x1, ..]] + :param org_img_shape: shape of patient volume used as world coordinates. + :return: list of mirrored patch crops: lenght=3. each element is a list of transformed patch crops. + """ + mirrored_patch_crops = [] + + # y-axis transform. + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + ii[2], ii[3]] if len(ii) == 4 else + [org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) + + # x-axis transform. + mirrored_patch_crops.append([[ii[0], ii[1], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2]] if len(ii) == 4 else + [ii[0], ii[1], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], + ii[4], ii[5]] for ii in patch_crops]) + + # y-axis and x-axis transform. + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2]] if len(ii) == 4 else + [org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], + ii[4], ii[5]] for ii in patch_crops]) + + return mirrored_patch_crops + + +