|
a |
|
b/predictor.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). |
|
|
3 |
# |
|
|
4 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
5 |
# you may not use this file except in compliance with the License. |
|
|
6 |
# You may obtain a copy of the License at |
|
|
7 |
# |
|
|
8 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
9 |
# |
|
|
10 |
# Unless required by applicable law or agreed to in writing, software |
|
|
11 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
13 |
# See the License for the specific language governing permissions and |
|
|
14 |
# limitations under the License. |
|
|
15 |
# ============================================================================== |
|
|
16 |
|
|
|
17 |
import os |
|
|
18 |
import code |
|
|
19 |
import numpy as np |
|
|
20 |
import torch |
|
|
21 |
from scipy.stats import norm |
|
|
22 |
from collections import OrderedDict |
|
|
23 |
from multiprocessing import Pool |
|
|
24 |
import pickle |
|
|
25 |
from copy import deepcopy |
|
|
26 |
import pandas as pd |
|
|
27 |
|
|
|
28 |
import utils.exp_utils as utils |
|
|
29 |
from plotting import plot_batch_prediction |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
class Predictor: |
|
|
33 |
""" |
|
|
34 |
Prediction pipeline: |
|
|
35 |
- receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. |
|
|
36 |
- forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) |
|
|
37 |
- unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward) |
|
|
38 |
|
|
|
39 |
Ensembling (mode == 'test'): |
|
|
40 |
- for inference, forwards 4 mirrored versions of the image through the model and unmolds predictions afterwards |
|
|
41 |
accordingly (method: data_aug_forward) |
|
|
42 |
- for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each |
|
|
43 |
parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set) |
|
|
44 |
|
|
|
45 |
Consolidation of predictions: |
|
|
46 |
- consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling, |
|
|
47 |
performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outputs. |
|
|
48 |
- for 2D networks, consolidates box predictions to 3D cubes via clustering (adaptation of non-maximum suppression). |
|
|
49 |
(external function: merge_2D_to_3D_preds_per_patient) |
|
|
50 |
|
|
|
51 |
Ground truth handling: |
|
|
52 |
- dismisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth) |
|
|
53 |
- if provided by data loader, adds 3D ground truth to the final predictions to be passed to the evaluator. |
|
|
54 |
""" |
|
|
55 |
def __init__(self, cf, net, logger, mode):
    """
    Store configuration, model and logger and prepare mode-specific state.

    :param cf: config object. In test mode the attributes fold_dir, test_n_epochs, test_aug and test_dir
        are read here.
    :param net: model instance. In validation mode, contains parameters of current epoch.
    :param logger: logger used by the prediction methods.
    :param mode: 'val' for patient-based validation/monitoring, 'test' for inference.
    :raises RuntimeError: in test mode, if 'epoch_ranking.npy' cannot be read from cf.fold_dir. The underlying
        error is attached as __cause__.
    """
    self.cf = cf
    self.logger = logger
    self.mode = mode
    self.net = net

    # rank of current epoch loaded (for temporal averaging). this info is added to each prediction,
    # for correct weighting during consolidation.
    self.rank_ix = '0'

    # number of ensembled models. used to calculate the number of expected predictions per position
    # during consolidation of predictions. Default is 1 (no ensembling, e.g. in validation).
    self.n_ens = 1

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if self.mode == 'test':
        ranking_file = os.path.join(self.cf.fold_dir, 'epoch_ranking.npy')
        try:
            self.epoch_ranking = np.load(ranking_file)[:cf.test_n_epochs]
        except (OSError, ValueError) as err:
            # only file-level failures are translated; interrupts and programming errors propagate untouched.
            raise RuntimeError('no epoch ranking file in fold directory. '
                               'seems like you are trying to run testing without prior training...') from err

        # one prediction set per loaded epoch, times 4 if mirrored test-time augmentation is on.
        self.n_ens = cf.test_n_epochs
        if self.cf.test_aug:
            self.n_ens *= 4

        # NOTE(review): placement inside the test branch is inferred (the directory is only used by
        # predict_test_set and cf.test_dir is a test-time setting) - confirm against the original layout.
        self.example_plot_dir = os.path.join(cf.test_dir, "example_plots")
        os.makedirs(self.example_plot_dir, exist_ok=True)
|
|
88 |
|
|
|
89 |
|
|
|
90 |
def predict_patient(self, batch):
    """
    Predict a single patient.

    Called either directly from the validation loop (mode == 'val') or from self.predict_test_set
    (mode == 'test').
    In val mode the patient-wise ground truth boxes are appended to the predictions, predictions of patched
    patients are consolidated (apply_wbc_to_patient) and, if cf.merge_2D_to_3D_preds is set, 2D boxes are
    merged to 3D (merge_2D_to_3D_preds_per_patient).
    In test mode the raw output of self.data_aug_forward is returned; ground truth addition, consolidation
    and merging are left to self.predict_test_set, which first collects predictions over all epochs.

    Side effects: prints a progress line and sets self.patched_patient.

    :param batch: dict for one patient. Keys read here: 'pid', optionally 'patch_crop_coords', and in val
        mode 'patient_bb_target' and 'patient_roi_labels'.
    :return: results_dict as returned by self.data_aug_forward, with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]].
             - 'seg_preds': pixel-wise predictions.
             - losses (only in validation mode)
    """
    print('\revaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold), end="", flush=True)

    # a patient provided in patches carries its crop coordinates; its predictions need to be tiled.
    self.patched_patient = 'patch_crop_coords' in batch

    results_dict = self.data_aug_forward(batch)

    # test mode: hand back the raw predictions.
    if self.mode != 'val':
        return results_dict

    # append the patient-wise ground truth to the boxes of the matching batch element.
    for b in range(batch['patient_bb_target'].shape[0]):
        for t, gt_coords in enumerate(batch['patient_bb_target'][b]):
            gt_box = {'box_coords': gt_coords,
                      'box_label': batch['patient_roi_labels'][b][t],
                      'box_type': 'gt'}
            results_dict['boxes'][b].append(gt_box)

    if self.patched_patient:
        consolidated = apply_wbc_to_patient(
            [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.wcs_iou, self.n_ens])
        results_dict['boxes'] = consolidated[0]

    if self.cf.merge_2D_to_3D_preds:
        merged = merge_2D_to_3D_preds_per_patient(
            [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou])
        results_dict['boxes'] = merged[0]

    return results_dict
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def predict_test_set(self, batch_gen, return_results=True):
    """
    Predict the test set with every selected checkpoint (temporal ensembling) and collect results per patient.

    For each checkpoint derived from self.epoch_ranking the weights are loaded into self.net, then
    batch_gen['n_test'] batches are drawn from batch_gen['test'] and passed to self.predict_patient.
    Afterwards the boxes of all checkpoints are flattened per patient and batch element, the patient-wise
    ground truth boxes are appended and the raw list is pickled to cf.fold_dir.
    If return_results is set, boxes are additionally consolidated (apply_wbc_to_patient), optionally merged
    from 2D to 3D (merge_2D_to_3D_preds_per_patient), pickled again and returned.

    :param batch_gen: dict with keys 'test' (iterator yielding one patient batch per next() call) and
        'n_test' (number of batches drawn per checkpoint).
    :param return_results: if False, only the raw predictions are saved and None is returned.
    :return: (optionally) list over patient results. each entry is [results_dict, pid], where results_dict
             has the key:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                        (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
             ('seg_preds' is not collected yet. todo for evaluation of instance/semantic segmentation.)
    """
    dict_of_patient_results = OrderedDict()

    # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling).
    weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch), 'params.pth')
                    for epoch in self.epoch_ranking]

    print(weight_paths)

    # at most one example plot per checkpoint.
    n_test_plots = min(batch_gen['n_test'], 1)

    for rank_ix, weight_path in enumerate(weight_paths):

        self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path)))
        # map_location lets checkpoints written on a GPU machine be restored on a CPU-only host;
        # load_state_dict copies the values into the existing parameters of self.net either way.
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()
        self.rank_ix = str(rank_ix)  # get string of current rank for unique patch ids.
        plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=n_test_plots, replace=False)

        with torch.no_grad():
            for i in range(batch_gen['n_test']):

                batch = next(batch_gen['test'])
                pid = batch['pid']

                # store batch info in patient entry of results dict. entries are only created during the
                # first pass - presumably the generator yields the same patients for every checkpoint.
                if rank_ix == 0:
                    dict_of_patient_results[pid] = {
                        'results_dicts': [],
                        'patient_bb_target': batch['patient_bb_target'],
                        'patient_roi_labels': batch['patient_roi_labels'],
                    }

                # call prediction pipeline and store results in dict.
                results_dict = self.predict_patient(batch)
                dict_of_patient_results[pid]['results_dicts'].append({"boxes": results_dict['boxes']})

                if i in plot_batches and not self.patched_patient:
                    # view qualitative results of random test case
                    # plotting for patched patients is too expensive, thus not done. Change at will.
                    try:
                        out_file = os.path.join(self.example_plot_dir,
                                                'batch_example_test_{}_rank_{}.png'.format(self.cf.fold,
                                                                                           rank_ix))
                        results_for_plotting = deepcopy(results_dict)
                        # seg preds of test augs are included separately. for viewing, only show aug 0 (merging
                        # would need multiple changes, incl in every model).
                        if results_for_plotting["seg_preds"].shape[1] > 1:
                            results_for_plotting["seg_preds"] = results_dict['seg_preds'][:, [0]]
                        for bix in range(batch["seg"].shape[0]):  # batch dim should be 1
                            for tix in range(len(batch['bb_target'][bix])):
                                results_for_plotting['boxes'][bix].append(
                                    {'box_coords': batch['bb_target'][bix][tix],
                                     'box_label': batch['class_target'][bix][tix],
                                     'box_type': 'gt'})
                        utils.split_off_process(plot_batch_prediction, batch, results_for_plotting, self.cf,
                                                outfile=out_file, suptitle="Test plot:\nunmerged TTA overlayed.")
                    except Exception as e:
                        # plotting is best-effort and must not abort the prediction run.
                        self.logger.info("WARNING: error in plotting example test batch: {}".format(e))

    self.logger.info('finished predicting test set. starting post-processing of predictions.')
    results_per_patient = []

    # loop over patients again to flatten results across epoch predictions.
    # if provided, add ground truth boxes for evaluation.
    for pid, p_dict in dict_of_patient_results.items():

        tmp_ens_list = p_dict['results_dicts']
        # collect all boxes of same batch_instance over temporal instances.
        b_size = len(tmp_ens_list[0]["boxes"])
        results_dict = {
            'boxes': [[item for rank_dict in tmp_ens_list for item in rank_dict["boxes"][batch_instance]]
                      for batch_instance in range(b_size)]
        }

        # TODO return for instance segmentation:
        # results_dict['seg_preds'] = np.mean(results_dict['seg_preds'], 1)[:, None]

        # add 3D ground truth boxes for evaluation.
        for b in range(p_dict['patient_bb_target'].shape[0]):
            for t in range(len(p_dict['patient_bb_target'][b])):
                results_dict['boxes'][b].append({'box_coords': p_dict['patient_bb_target'][b][t],
                                                 'box_label': p_dict['patient_roi_labels'][b][t],
                                                 'box_type': 'gt'})
        results_per_patient.append([results_dict, pid])

    # save out raw predictions.
    out_string = 'raw_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'raw_pred_boxes_list'
    with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
        pickle.dump(results_per_patient, handle)

    if return_results:
        final_patient_box_results = [(res_dict["boxes"], pid) for res_dict, pid in results_per_patient]
        # consolidate predictions.
        self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
            self.cf.wcs_iou, self.n_ens))
        mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, self.n_ens]
                     for ii in final_patient_box_results]
        # the context manager releases the worker processes even if a worker raises.
        with Pool(processes=8) as pool:
            final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)

        # merge 2D boxes to 3D cubes. (if model predicts 2D but evaluation is run in 3D)
        if self.cf.merge_2D_to_3D_preds:
            self.logger.info(
                'applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou]
                         for ii in final_patient_box_results]
            with Pool(processes=6) as pool:
                final_patient_box_results = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)

        # final_patient_box_results holds [avg_boxes, pid] if wbc
        for ix in range(len(results_per_patient)):
            assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid"
            results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0]

        out_string = 'wbc_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'wbc_pred_boxes_list'
        with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
            pickle.dump(results_per_patient, handle)

        return results_per_patient
|
|
266 |
|
|
|
267 |
|
|
|
268 |
def load_saved_predictions(self, apply_wbc=False):
    """
    Load the raw predictions saved by self.predict_test_set and optionally post-process them.

    If cf.hold_out_test_set and cf.ensemble_folds are both set, the result lists of all 'fold_<i>' directories
    found in cf.exp_dir (i < cf.n_cv_splits) are concatenated; otherwise only the list in cf.fold_dir is loaded.
    With apply_wbc, boxes are consolidated (apply_wbc_to_patient) and, if cf.merge_2D_to_3D_preds is set,
    merged from 2D to 3D cubes (merge_2D_to_3D_preds_per_patient).

    :param apply_wbc: whether to consolidate the loaded raw predictions.
    :return: results_list: list over patient results. each entry is (results_dict, pid), where results_dict
             has the key:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                        (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
             ('seg_preds' is not implemented yet. todo for evaluation of instance/semantic segmentation.)
    """
    # NOTE: pickle.load can execute arbitrary code; only load result files written by predict_test_set.
    results_file = 'raw_pred_boxes_hold_out_list.pickle' if self.cf.hold_out_test_set else 'raw_pred_boxes_list.pickle'
    da_factor = 4 if self.cf.test_aug else 1

    # load predictions for a single test-set fold.
    if not self.cf.hold_out_test_set or not self.cf.ensemble_folds:
        with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle:
            results_list = pickle.load(handle)
        n_ens = self.cf.test_n_epochs * da_factor
        self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format(
            len(results_list), n_ens))

    # if hold out test set was predicted, aggregate predictions of all trained models
    # corresponding to all CV-folds and flatten them.
    else:
        self.logger.info("loading saved predictions of hold-out test set and ensembling over folds.")
        fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if
                            os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")])

        results_list = []
        folds_loaded = 0
        for fold in range(self.cf.n_cv_splits):
            fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold))
            if fold_dir in fold_dirs:
                with open(os.path.join(fold_dir, results_file), 'rb') as handle:
                    fold_list = pickle.load(handle)
                results_list += fold_list
                folds_loaded += 1
            else:
                self.logger.info("Skipping fold {} since no saved predictions found.".format(fold))

        # every loaded fold contributes a full set of epoch/augmentation predictions.
        n_ens = self.cf.test_n_epochs * da_factor * folds_loaded

    # boxes are passed on without filtering ground truth out.
    box_results_list = [(res_dict['boxes'], pid) for res_dict, pid in results_list]

    # consolidate predictions.
    # NOTE(review): nesting of the 2Dto3D merge and of the write-back loop under apply_wbc is inferred
    # (it mirrors the layout under return_results in predict_test_set) - confirm against the original.
    if apply_wbc:
        self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
            self.cf.wcs_iou, n_ens))
        mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, n_ens] for ii in box_results_list]
        # the context manager releases the worker processes even if a worker raises.
        with Pool(processes=6) as pool:
            box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)

        # merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D)
        if self.cf.merge_2D_to_3D_preds:
            self.logger.info(
                'applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list]
            with Pool(processes=6) as pool:
                box_results_list = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)

        for ix in range(len(results_list)):
            assert np.all(
                results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results"
            results_list[ix][0]["boxes"] = box_results_list[ix][0]

    return results_list  # holds (results_dict, pid)
|
|
342 |
|
|
|
343 |
|
|
|
344 |
def data_aug_forward(self, batch):
    """
    Forward a patient through spatial_tiling_forward, optionally with mirroring test-time augmentation.

    in val_mode: passes batch through to spatial_tiling method without data_aug.
    in test_mode: if cf.test_aug is set in configs, creates 3 additional mirrored versions of the input image
    (flipped along array axis 2, axis 3, and both), passes each of them to spatial_tiling_forward and
    re-transforms the returned boxes and segmentation predictions to the original image orientation.
    batch['data'] is temporarily replaced during augmentation and restored afterwards, also if a forward
    pass raises.

    :param batch: dict for one patient. Keys read here: 'data', 'original_img_shape' and, for patched
        patients, 'patch_crop_coords'.
    :return. results_dict: stores the results for one patient. dictionary with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. boxes of all augmentations are concatenated.
             - 'seg_preds': pixel-wise predictions; the predictions of all augmentations are stacked along axis 1.
             - losses (only in validation mode)
    """
    patch_crops = batch['patch_crop_coords'] if self.patched_patient else None
    results_list = [self.spatial_tiling_forward(batch, patch_crops)]
    org_img_shape = batch['original_img_shape']

    if self.mode == 'test' and self.cf.test_aug:

        if self.patched_patient:
            # apply mirror transformations to patch-crop coordinates, for correct tiling in spatial_tiling method.
            mirrored_patch_crops = get_mirrored_patch_crops(patch_crops, batch['original_img_shape'])
        else:
            mirrored_patch_crops = [None] * 3

        img = np.copy(batch['data'])

        # (augmentation id used in patch ids, array axes that get flipped)
        mirror_passes = (('1', (2,)), ('2', (3,)), ('3', (2, 3)))
        try:
            for aug_ix, (n_aug, flip_axes) in enumerate(mirror_passes):
                mirrored_img = img
                for axis in flip_axes:
                    mirrored_img = np.flip(mirrored_img, axis=axis)
                batch['data'] = mirrored_img.copy()
                chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[aug_ix], n_aug=n_aug)

                # re-transform coordinates: entries 0/2 of box_coords run along array axis 2,
                # entries 1/3 along array axis 3. mirroring swaps the lower and upper corner.
                for element_boxes in chunk_dict['boxes']:
                    for box in element_boxes:
                        mirrored_coords = box['box_coords']
                        coords = mirrored_coords.copy()
                        if 2 in flip_axes:
                            coords[0] = org_img_shape[2] - mirrored_coords[2]
                            coords[2] = org_img_shape[2] - mirrored_coords[0]
                        if 3 in flip_axes:
                            coords[1] = org_img_shape[3] - mirrored_coords[3]
                            coords[3] = org_img_shape[3] - mirrored_coords[1]
                        assert coords[2] >= coords[0], [coords, mirrored_coords.copy()]
                        assert coords[3] >= coords[1], [coords, mirrored_coords.copy()]
                        box['box_coords'] = coords

                # re-transform segmentation predictions.
                seg_preds = chunk_dict['seg_preds']
                for axis in flip_axes:
                    seg_preds = np.flip(seg_preds, axis=axis)
                chunk_dict['seg_preds'] = seg_preds
                results_list.append(chunk_dict)
        finally:
            # never hand a mirrored image back to the caller, not even on failure.
            batch['data'] = img

    # aggregate all boxes/seg_preds per batch element from data_aug predictions.
    results_dict = {}
    results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]]
                             for batch_instance in range(org_img_shape[0])]
    results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]]
                                          for batch_instance in range(org_img_shape[0])])
    if self.mode == 'val':
        # losses are optional in the forward output.
        try:
            results_dict['torch_loss'] = results_list[0]['torch_loss']
            results_dict['class_loss'] = results_list[0]['class_loss']
        except KeyError:
            pass
    return results_dict
|
|
436 |
|
|
|
437 |
|
|
|
438 |
def spatial_tiling_forward(self, batch, patch_crops=None, n_aug='0'):
    """
    Forward a batch via batch_tiling_forward and, for patched patients, re-assemble whole-image predictions.

    if patch-based prediction, the results received from batch_tiling_forward are on a per-patch basis.
    this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates:
    segmentation outputs are averaged where patches overlap and box coordinates are shifted by the patch origin.
    Each box additionally receives
      - 'patch_id': unique string per patch (rank of the loaded epoch, data aug instance, patch index),
      - 'box_patch_center_factor': down-weighting of boxes far from the patch center,
      - 'box_n_overlaps': mean number of overlapping patches inside the box,
    which are stored for the later consolidation of predictions.

    :param batch: dict for one patient; 'original_img_shape' is read here for patched patients.
    :param patch_crops: None for non-patched prediction, else one entry per patch. pc[0]:pc[1] indexes array
        axis 2, pc[2]:pc[3] axis 3 and pc[4]:pc[5] axis 4 (3D) or the slice/batch axis 0 (2D).
    :param n_aug: string id of the current data augmentation instance.
    :return. results_dict: stores the results for one patient. dictionary with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                        and a dummy batch dimension of 1 for 3D predictions.
             - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
             - losses (only in validation mode)
    """
    # if predictions are not patch-based:
    # add patch-origin info to boxes (entire image is the same patch with overlap=1) and return results.
    if patch_crops is None:
        results_dict = self.batch_tiling_forward(batch)
        for element_boxes in results_dict['boxes']:
            for box in element_boxes:
                box['box_patch_center_factor'] = 1
                box['box_n_overlaps'] = 1
                box['patch_id'] = self.rank_ix + '_' + n_aug
        return results_dict

    patches_dict = self.batch_tiling_forward(batch)

    results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]}

    # instantiate segmentation output array. Will contain averages over patch predictions.
    out_seg_preds = np.zeros(batch['original_img_shape'], dtype=np.float16)[:, 0][:, None]
    # counts patch instances per pixel-position.
    patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8')

    # unmold segmentation outputs. loop over patches.
    for pix, pc in enumerate(patch_crops):
        if self.cf.dim == 3:
            out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix][None]
            patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1
        else:
            out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix]
            patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1

    # take mean in overlapping areas.
    out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0]
    results_dict['seg_preds'] = out_seg_preds

    patch_half_sizes = np.array(self.cf.patch_size) / 2

    # unmold box outputs. loop over patches.
    for pix, pc in enumerate(patch_crops):
        for box in patches_dict['boxes'][pix]:

            # add unique patch id for consolidation of predictions.
            box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix)

            # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers.
            # hence they will be downweighted for consolidation, using the 'box_patch_center_factor', which is
            # obtained by a normal distribution over positions in the patch and average over spatial dimensions
            # (scaled such that a box at the exact patch center gets factor 1).
            c = box['box_coords']  # still in patch coordinates here.
            box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)]
            if self.cf.dim == 3:
                box_centers.append((c[4] + c[5]) / 2)
            box['box_patch_center_factor'] = np.mean(
                [norm.pdf(bc, loc=half, scale=half * 0.8) * np.sqrt(2 * np.pi) * half * 0.8
                 for bc, half in zip(box_centers, patch_half_sizes)])

            # shift to whole-image coordinates (in place, so box['box_coords'] is updated as well).
            # entries 0/2 run along array axis 2 (offset pc[0]), entries 1/3 along array axis 3 (offset pc[2]).
            if self.cf.dim == 3:
                c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]])
            else:
                c += np.array([pc[0], pc[2], pc[0], pc[2]])

            # 'box_n_overlaps' is the mean patch count inside the box: round lower corners down (clamped to the
            # image), upper corners up, and slice each array axis with the coordinates that belong to it.
            lo_2, lo_3 = (max(int(np.floor(v)), 0) for v in (c[0], c[1]))
            hi_2, hi_3 = (int(np.ceil(v)) for v in (c[2], c[3]))
            if self.cf.dim == 3:
                lo_4 = max(int(np.floor(c[4])), 0)
                hi_4 = int(np.ceil(c[5]))
                box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, lo_2:hi_2, lo_3:hi_3, lo_4:hi_4])
                results_dict['boxes'][0].append(box)
            else:
                # pc[4] is used as slice index - presumably every 2D patch covers exactly one slice.
                box['box_n_overlaps'] = np.mean(patch_overlap_map[pc[4], :, lo_2:hi_2, lo_3:hi_3])
                results_dict['boxes'][pc[4]].append(box)

    if self.mode == 'val':
        # losses are optional in the forward output.
        try:
            results_dict['torch_loss'] = patches_dict['torch_loss']
            results_dict['class_loss'] = patches_dict['class_loss']
        except KeyError:
            pass

    return results_dict
|
|
527 |
|
|
|
528 |
|
|
|
529 |
def batch_tiling_forward(self, batch): |
|
|
530 |
""" |
|
|
531 |
calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed |
|
|
532 |
with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of |
|
|
533 |
batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded). |
|
|
534 |
test mode calls the test forward method, no ground truth required / involved. |
|
|
535 |
:return. results_dict: stores the results for one patient. dictionary with keys: |
|
|
536 |
- 'boxes': list over batch elements. each element is a list over boxes, where each box is |
|
|
537 |
one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, |
|
|
538 |
and a dummy batch dimension of 1 for 3D predictions. |
|
|
539 |
- 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) |
|
|
540 |
- losses (only in validation mode) |
|
|
541 |
""" |
|
|
542 |
#self.logger.info('forwarding (patched) patient with shape: {}'.format(batch['data'].shape)) |
|
|
543 |
|
|
|
544 |
img = batch['data'] |
|
|
545 |
|
|
|
546 |
#batch['data'] = torch.from_numpy(batch['data']).float().to(self.device) |
|
|
547 |
|
|
|
548 |
if img.shape[0] <= self.cf.batch_size: |
|
|
549 |
|
|
|
550 |
if self.mode == 'val': |
|
|
551 |
# call training method to monitor losses |
|
|
552 |
results_dict = self.net.train_forward(batch, is_validation=True) |
|
|
553 |
# discard returned ground-truth boxes (also training info boxes). |
|
|
554 |
results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] |
|
|
555 |
else: |
|
|
556 |
results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test) |
|
|
557 |
|
|
|
558 |
else: |
|
|
559 |
split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.cf.batch_size]) |
|
|
560 |
chunk_dicts = [] |
|
|
561 |
for chunk_ixs in split_ixs[1:]: # first split is elements before 0, so empty |
|
|
562 |
b = {k: batch[k][chunk_ixs] for k in batch.keys() |
|
|
563 |
if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])} |
|
|
564 |
if self.mode == 'val': |
|
|
565 |
chunk_dicts += [self.net.train_forward(b, is_validation=True)] |
|
|
566 |
else: |
|
|
567 |
chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)] |
|
|
568 |
|
|
|
569 |
|
|
|
570 |
results_dict = {} |
|
|
571 |
# flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...]) |
|
|
572 |
results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']] |
|
|
573 |
results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']]) |
|
|
574 |
|
|
|
575 |
if self.mode == 'val': |
|
|
576 |
try: |
|
|
577 |
# estimate metrics by mean over batch_chunks. Most similar to training metrics. |
|
|
578 |
results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts])) |
|
|
579 |
results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts]) |
|
|
580 |
except KeyError: |
|
|
581 |
# losses are not necessarily monitored |
|
|
582 |
pass |
|
|
583 |
# discard returned ground-truth boxes (also training info boxes). |
|
|
584 |
results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] |
|
|
585 |
|
|
|
586 |
return results_dict |
|
|
587 |
|
|
|
588 |
|
|
|
589 |
|
|
|
590 |
def apply_wbc_to_patient(inputs):
    """
    Consolidate the detections of a single patient with weighted box clustering (wbc).

    For every batch element (1 in 3D, slices in 2D) and every class id in class_dict, all 'det' boxes of that
    class are handed to weighted_box_clustering and replaced by the resulting cluster boxes. Boxes of type 'gt'
    are copied to the output unchanged; all other boxes are dropped.

    :param inputs: tuple (in_patient_results_list, pid, class_dict, wcs_iou, n_ens).
    :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is
                                   one dictionary: [[box_0, ...], [box_n,...]].
    :return. pid: string. patient id.
    """
    in_patient_results_list, pid, class_dict, wcs_iou, n_ens = inputs
    out_patient_results_list = []

    for element_boxes in in_patient_results_list:
        consolidated = []

        for cl in list(class_dict.keys()):
            dets = [box for box in element_boxes
                    if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]

            box_coords = np.array([det['box_coords'] for det in dets])
            box_scores = np.array([det['box_score'] for det in dets])
            box_center_factor = np.array([det['box_patch_center_factor'] for det in dets])
            box_n_overlaps = np.array([det['box_n_overlaps'] for det in dets])
            box_patch_id = np.array([det['patch_id'] for det in dets])

            # nothing predicted for this class in this batch element.
            if 0 in box_scores.shape:
                continue

            # one row per detection: coords, score, patch center factor, number of overlaps.
            stacked = np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None],
                                      box_n_overlaps[:, None]), axis=1)
            keep_scores, keep_coords = weighted_box_clustering(stacked, box_patch_id, wcs_iou, n_ens)

            for score, coords in zip(keep_scores, keep_coords):
                consolidated.append({'box_type': 'det', 'box_coords': coords,
                                     'box_score': score, 'box_pred_class_id': cl})

        # add gt boxes back to new output list.
        consolidated.extend([box for box in element_boxes if box['box_type'] == 'gt'])
        out_patient_results_list.append(consolidated)

    return [out_patient_results_list, pid]
|
|
627 |
|
|
|
628 |
|
|
|
629 |
|
|
|
630 |
def merge_2D_to_3D_preds_per_patient(inputs):
    """
    Merge the 2D slice-wise detections of a single patient into 3D boxes.

    Per class id in class_dict, the 'det' boxes of all slices (batch dimension) are collected together with the
    index of the slice they came from and passed to nms_2to3D. Every kept box gets the returned z-range appended
    to its 2D coordinates. Ground-truth boxes are appended unchanged and must already carry 6 coordinates.

    :param inputs: tuple (in_patient_results_list, pid, class_dict, merge_3D_iou).
    :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each
                                 box is one dictionary: [[box_0, ...], [box_n,...]].
    :return. pid: string. patient id.
    """
    in_patient_results_list, pid, class_dict, merge_3D_iou = inputs
    merged_boxes = []

    for cl in list(class_dict.keys()):
        # gather this class' detections over all slices and remember the slice index of each.
        dets, det_slices = [], []
        for slice_ix, slice_boxes in enumerate(in_patient_results_list):
            for box in slice_boxes:
                if box['box_type'] == 'det' and box['box_pred_class_id'] == cl:
                    dets.append(box)
                    det_slices.append(slice_ix)

        box_coords = np.array([det['box_coords'] for det in dets])
        box_scores = np.array([det['box_score'] for det in dets])
        slice_ids = np.array(det_slices)

        if 0 in box_scores.shape:
            keep_ix, keep_z = [], []
        else:
            stacked = np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1)
            keep_ix, keep_z = nms_2to3D(stacked, merge_3D_iou)

        # store kept predictions in new results list and add corresponding z-dimension info to coordinates.
        for kix, kz in zip(keep_ix, keep_z):
            merged_boxes.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz,
                                 'box_score': box_scores[kix], 'box_pred_class_id': cl})

    gt_boxes = [box for slice_boxes in in_patient_results_list for box in slice_boxes
                if box['box_type'] == 'gt']
    if len(gt_boxes) > 0:
        assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D."
    merged_boxes += gt_boxes

    # add dummy batch dimension 1 for 3D.
    return [[merged_boxes], pid]
|
|
673 |
|
|
|
674 |
|
|
|
675 |
|
|
|
676 |
def weighted_box_clustering(dets, box_patch_id, thresh, n_ens):
    """
    consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling.
    clusters predictions together with iou > thresh (like in NMS). Output score and coordinates for one cluster are
    averages weighted by the iou with the cluster's highest scoring box, the box size and the individual patch center
    factor (how trustworthy is this candidate measured by how centered its position in the patch is).
    The number of expected predictions at a position is n_ens * mean(n_overlaps of the boxes in the cluster)
    (1 prediction per unique patch). Missing predictions at a cluster position are the expected predictions minus
    the number of unique patches that contributed a box to the cluster (clipped at 0); they enter the score average
    with the mean weight of the cluster and a score of 0.

    The highest scoring box is always part of its own cluster and is always removed after the cluster has been
    processed. This guarantees termination even if the box's iou with itself does not exceed thresh
    (thresh >= 1 or degenerate boxes).

    :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs)
    :param box_patch_id: (n_dets) array-like of patch identifiers, one per detection.
    :param thresh: threshold for iou_matching.
    :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position)
    :return: keep_scores: (n_keep) new scores of boxes to be kept.
    :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept.
    """
    box_patch_id = np.asarray(box_patch_id)
    dim = 2 if dets.shape[1] == 7 else 3

    y1 = dets[:, 0]
    x1 = dets[:, 1]
    y2 = dets[:, 2]
    x2 = dets[:, 3]
    scores = dets[:, -3]
    box_pc_facts = dets[:, -2]
    box_n_ovs = dets[:, -1]

    # coordinate columns in output order; the z-columns only exist in 3D.
    coord_columns = [y1, x1, y2, x2]
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    if dim == 3:
        z1 = dets[:, 4]
        z2 = dets[:, 5]
        coord_columns += [z1, z2]
        areas = areas * (z2 - z1 + 1)

    # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24)
    order = scores.argsort()[::-1]

    keep_scores = []
    keep_coords = []

    while order.size > 0:
        i = order[0]  # highest scoring element

        # intersection of the highest scoring box with all remaining boxes (including itself).
        w = np.maximum(0.0, np.minimum(x2[i], x2[order]) - np.maximum(x1[i], x1[order]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[order]) - np.maximum(y1[i], y1[order]) + 1)
        inter = w * h
        if dim == 3:
            d = np.maximum(0.0, np.minimum(z2[i], z2[order]) - np.maximum(z1[i], z1[order]) + 1)
            inter = inter * d

        # overlap between currently highest scoring box and all boxes.
        ovr = inter / (areas[i] + areas[order] - inter)

        # get all the predictions that match the current box to build one cluster. the current box itself is
        # forced into the cluster, otherwise it could never be removed from 'order' (endless loop).
        in_cluster = ovr > thresh
        in_cluster[0] = True
        cluster = order[in_cluster]

        # weight all scores in cluster by overlap with the current box, size and patch center factor.
        match_score_weights = ovr[in_cluster] * areas[cluster] * box_pc_facts[cluster]
        match_scores = scores[cluster] * match_score_weights

        # for the weighted average, scores have to be divided by the number of total expected preds at the
        # position of the current cluster. 1 prediction per patch is expected. therefore, the number of ensembled
        # models is multiplied by the mean overlaps of patches at this position (boxes of the cluster might
        # partly be in areas of different overlaps).
        n_expected_preds = n_ens * np.mean(box_n_ovs[cluster])

        # the number of missing predictions is obtained as the number of patches,
        # which did not contribute any prediction to the current cluster.
        n_missing_preds = np.max((0, n_expected_preds - np.unique(box_patch_id[cluster]).shape[0]))

        # missing preds are given the mean weighting
        # (expected prediction is the mean over all predictions in cluster).
        denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights)

        # compute weighted average score for the cluster
        score_sum = np.sum(match_scores)
        avg_score = score_sum / denom

        # some clusters might have very low scores due to high amounts of missing predictions.
        # filter them out with a conservative threshold, to speed up evaluation.
        if avg_score > 0.01:
            # weighted average of coordinates for the cluster. only existing predictions are taken into account.
            avg_coords = [np.sum(column[cluster] * match_scores) / score_sum for column in coord_columns]
            keep_scores.append(avg_score)
            keep_coords.append(avg_coords)

        # keep only elements that were not matched. the current box is discarded unconditionally.
        remaining = ovr <= thresh
        remaining[0] = False
        order = order[remaining]

    return keep_scores, keep_coords
|
|
786 |
|
|
|
787 |
|
|
|
788 |
|
|
|
789 |
def nms_2to3D(dets, thresh):
    """
    Merges 2D boxes to 3D cubes. Boxes of all slices are treated as if projected into one slice and an adaptation
    of non-maximum suppression is applied: clusters are found like in NMS (iou > thresh with the highest scoring
    box), with the extra constraint that suppressed boxes need 'connected' z-coordinates w.r.t. the core slice
    (slice of the highest scoring box). 'connected' z-coordinates are the slices with cluster predictions up to
    the first slice (in each direction) where no prediction of the cluster was found.

    example: a cluster of predictions was found with overlap > thresh in xy. The z-coordinate of the highest
    scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57.
    Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not, because nothing
    was found in 47, so 47 is a 'hole' which interrupts the connection). Only the boxes corresponding to these
    coordinates are suppressed. All others are kept for building further clusters.

    The returned z-range of a cube is [lowest connected slice - 1, highest connected slice + 1].

    :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id)
    :param thresh: iou matching threshold (like in NMS).
    :return: keep: (n_keep) list of indices (into dets) of boxes to be kept.
    :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes.
    """
    y1, x1, y2, x2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    scores = dets[:, -2]
    slice_id = dets[:, -1]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # indices of detections, highest score first.
    order = scores.argsort()[::-1]

    keep, keep_z = [], []

    while order.size > 0:
        i = order[0]  # highest scoring remaining box = core of the next cluster.

        # in-plane iou of the core box with every remaining box (including itself).
        w = np.maximum(0.0, np.minimum(x2[i], x2[order]) - np.maximum(x1[i], x1[order]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[order]) - np.maximum(y1[i], y1[order]) + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order] - inter)

        # positions (within order) of all boxes that overlap the core box sufficiently.
        matches = np.flatnonzero(ovr > thresh)
        cluster_slices = slice_id[order[matches]]
        core_slice = slice_id[int(i)]
        top_slice = np.max(cluster_slices)
        bottom_slice = np.min(cluster_slices)

        # slices between core and cluster boundary without any cluster prediction interrupt the connection.
        upper_holes = [z for z in np.arange(core_slice, top_slice) if z not in cluster_slices]
        lower_holes = [z for z in np.arange(bottom_slice, core_slice) if z not in cluster_slices]
        max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else top_slice
        min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else bottom_slice

        connected = matches[(cluster_slices <= max_valid_slice_id) & (cluster_slices >= min_valid_slice_id)]
        connected_slices = slice_id[order[connected]]

        keep.append(i)
        keep_z.append([np.min(connected_slices) - 1, np.max(connected_slices) + 1])

        # only the connected boxes are suppressed, everything else stays available for further clusters.
        order = np.delete(order, connected, axis=0)

    return keep, keep_z
|
|
853 |
|
|
|
854 |
|
|
|
855 |
|
|
|
856 |
def get_mirrored_patch_crops(patch_crops, org_img_shape):
    """
    Apply 3 mirror transformations (y-axis, x-axis, y- & x-axis) to the given patch crop coordinates and return
    the transformed coordinates. Mirroring an axis maps a coordinate pair (a1, a2) to (extent - a2, extent - a1),
    where the extent is org_img_shape[2] for y and org_img_shape[3] for x. Crops with 4 entries are handled as 2D;
    for all other crops entries 4 and 5 (z-coordinates) are passed through unchanged.

    :param patch_crops: list of crops: each element is a list of coordinates for given crop
                        [[y1, y2, x1, x2, (z1), (z2)], ...]
    :param org_img_shape: shape of patient volume used as world coordinates.
    :return: list of mirrored patch crops: length=3 (y, x, y&x). each element is a list of transformed patch crops.
    """
    def mirrored(crop, flip_y, flip_x):
        # image extents are looked up lazily, only for the axes that are actually flipped.
        if flip_y:
            y_pair = [org_img_shape[2] - crop[1], org_img_shape[2] - crop[0]]
        else:
            y_pair = [crop[0], crop[1]]
        if flip_x:
            x_pair = [org_img_shape[3] - crop[3], org_img_shape[3] - crop[2]]
        else:
            x_pair = [crop[2], crop[3]]
        z_pair = [] if len(crop) == 4 else [crop[4], crop[5]]
        return y_pair + x_pair + z_pair

    # order of transforms: y-axis, x-axis, y-axis and x-axis.
    flips = ((True, False), (False, True), (True, True))
    return [[mirrored(crop, flip_y, flip_x) for crop in patch_crops] for flip_y, flip_x in flips]
|
|
896 |
|
|
|
897 |
|
|
|
898 |
|