|
a |
|
b/plotting.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). |
|
|
3 |
# |
|
|
4 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
5 |
# you may not use this file except in compliance with the License. |
|
|
6 |
# You may obtain a copy of the License at |
|
|
7 |
# |
|
|
8 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
9 |
# |
|
|
10 |
# Unless required by applicable law or agreed to in writing, software |
|
|
11 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
13 |
# See the License for the specific language governing permissions and |
|
|
14 |
# limitations under the License. |
|
|
15 |
# ============================================================================== |
|
|
16 |
from typing import Union |
|
|
17 |
|
|
|
18 |
import matplotlib |
|
|
19 |
matplotlib.use('Agg') |
|
|
20 |
import matplotlib.pyplot as plt |
|
|
21 |
import matplotlib.gridspec as gridspec |
|
|
22 |
import numpy as np |
|
|
23 |
import os |
|
|
24 |
from copy import deepcopy |
|
|
25 |
import pickle |
|
|
26 |
|
|
|
27 |
def suppress_axes_lines(ax):
    """
    Remove the ticks of both axes and hide all four spines of the given axes.
    :param ax: pyplot axes object
    """
    # clear the ticks on x first, then on y.
    for axis_getter_name in ('get_xaxis', 'get_yaxis'):
        axis = getattr(ax.axes, axis_getter_name)()
        axis.set_ticks([])

    # hide the frame lines around the plot area.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    return
|
|
39 |
|
|
|
40 |
def plot_batch_prediction(batch: dict, results_dict: dict, cf, outfile: Union[str, None]=None,
                          suptitle: Union[str, None]=None):
    """
    Plot the input images, ground truth annotations, and output predictions of a batch and save the figure.

    If cf.dim == 3, one randomly sampled batch element (patient) is projected into 2D slices. Since plotting
    all slices of a patient volume blows up costs of time and space, only a section of slices is plotted: the
    z-range of the first ground truth box padded by 5 slices, or a central section if there is no ground truth box.

    :param batch: dict with keys: 'data' (input image), 'seg' (pixelwise annotations), 'pid'.
    :param results_dict: dict with keys 'seg_preds' (predicted segmentations) and 'boxes'. 'boxes' is a list over
        batch elements. Each element is a list of boxes (prediction and ground truth), where every box is a
        dictionary containing box_coords and box_type, plus box_score / box_pred_class_id for detections and
        box_label for ground truth. The boxes are deep-copied here, so the caller's dicts are not modified.
    :param cf: config object. Attributes read here: dim, num_seg_classes, box_color_palette, and
        (only if outfile is None) plot_dir and fold.
    :param outfile: path of the image file to write. Defaults to 'pred_example_<fold>.png' inside cf.plot_dir.
    :param suptitle: optional title placed above all panels.
    :raises Warning: if the shapes of data, segs and seg_preds disagree, or if saving the figure fails.
    """
    print("Outfile beginning" + str(outfile))
    if outfile is None:
        outfile = os.path.join(cf.plot_dir, 'pred_example_{}.png'.format(cf.fold))

    data = batch['data']
    segs = batch['seg']
    pids = batch['pid']
    # for 3D, repeat pid over batch elements.
    if len(set(pids)) == 1:
        pids = [pids] * data.shape[0]

    seg_preds = results_dict['seg_preds']
    roi_results = deepcopy(results_dict['boxes'])

    # Randomly sampled one patient of batch and project data into 2D slices for plotting.
    if cf.dim == 3:
        patient_ix = np.random.choice(data.shape[0])
        data = np.transpose(data[patient_ix], axes=(3, 0, 1, 2))

        # select interesting foreground section to plot.
        gt_boxes = [box['box_coords'] for box in roi_results[patient_ix] if box['box_type'] == 'gt']
        if len(gt_boxes) > 0:
            z_cuts = [np.max((int(gt_boxes[0][4]) - 5, 0)), np.min((int(gt_boxes[0][5]) + 5, data.shape[0]))]
        else:
            # no ground truth box: plot a central section. The start is clamped at 0, because a negative
            # start index would wrap around and select slices from the end of the volume instead.
            mid_slice = data.shape[0] // 2
            z_cuts = [max(mid_slice - 5, 0), int(mid_slice + np.min([10, mid_slice]))]
        p_roi_results = roi_results[patient_ix]
        roi_results = [[] for _ in range(data.shape[0])]

        # iterate over cubes and spread across slices.
        for box in p_roi_results:
            cube_coords = box['box_coords']
            # dismiss negative anchor slices.
            slices = np.round(np.unique(np.clip(np.arange(cube_coords[4], cube_coords[5] + 1), 0, data.shape[0]-1)))
            for s in slices:
                roi_results[int(s)].append(box)
                # keep only the first four (in-plane) coordinates for the per-slice plot.
                roi_results[int(s)][-1]['box_coords'] = cube_coords[:4]

        roi_results = roi_results[z_cuts[0]: z_cuts[1]]
        data = data[z_cuts[0]: z_cuts[1]]
        segs = np.transpose(segs[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
        seg_preds = np.transpose(seg_preds[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
        pids = [pids[patient_ix]] * data.shape[0]

    # all dimensions except for the 'channel-dimension' are required to match.
    # Checked explicitly instead of with assert, which is stripped when python runs with -O.
    try:
        shapes_agree = all(data.shape[i] == segs.shape[i] == seg_preds.shape[i] for i in (0, 2, 3))
    except IndexError:
        # an array with fewer than four dimensions cannot agree with the others.
        shapes_agree = False
    if not shapes_agree:
        raise Warning('Shapes of arrays to plot not in agreement!'
                      'Shapes {} vs. {} vs {}'.format(data.shape, segs.shape, seg_preds.shape))

    # rows of the figure: input channels, ground truth seg, predicted seg, and the first input channel again.
    show_arrays = np.concatenate([data, segs, seg_preds, data[:, 0][:, None]], axis=1).astype(float)
    n_columns, n_rows = show_arrays.shape[0], show_arrays.shape[1]
    n_channels = data.shape[1]
    approx_figshape = (4 * n_columns, 4 * n_rows)
    fig = plt.figure(figsize=approx_figshape)
    try:
        gs = gridspec.GridSpec(n_rows + 1, n_columns)
        gs.update(wspace=0.1, hspace=0.1)
        for b in range(n_columns):
            for m in range(n_rows):

                ax = plt.subplot(gs[m, b])
                suppress_axes_lines(ax)
                arr = show_arrays[b, m]

                is_last_row = m == n_rows - 1
                if m < n_channels or is_last_row:
                    # input image rows are shown in gray scale with automatic intensity range.
                    if b == 0:
                        ax.set_ylabel("Input" + (" + GT & Pred Box" if is_last_row else ""))
                    cmap = 'gray'
                    vmin = None
                    vmax = None
                else:
                    # segmentation rows share a fixed value range over all classes.
                    cmap = None
                    vmin = 0
                    vmax = cf.num_seg_classes - 1

                if m == 0:
                    plt.title('{}'.format(pids[b][:10]), fontsize=20)

                plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                if m >= n_channels:
                    if b == 0:
                        if m == n_channels:
                            ax.set_ylabel("GT Box & Seg")
                        if m == n_channels + 1:
                            ax.set_ylabel("GT Box + Pred Seg & Box")
                    _draw_rois(roi_results[b], m, n_channels, cf)

        if suptitle is not None:
            plt.suptitle(suptitle, fontsize=22)

        print("Outfile end" + str(outfile))
        try:
            plt.savefig(outfile)
        except Exception as err:
            raise Warning('failed to save plot.') from err
    finally:
        # always release the figure, also if plotting or saving failed.
        plt.close(fig)


def _draw_rois(boxes, m, n_channels, cf):
    """
    Draw the outlines and text labels of the given boxes onto the current pyplot axes.
    :param boxes: list of box dictionaries belonging to the plotted image.
    :param m: row index of the current panel in the figure.
    :param n_channels: number of input channel rows that precede the segmentation rows.
    :param cf: config object providing box_color_palette.
    """
    for box in boxes:
        if box['box_type'] == 'patient_tn_box':  # don't plot true negative dummy boxes.
            continue

        coords = box['box_coords']
        if box['box_type'] == 'det':
            # dont plot background preds or low confidence boxes.
            if not (box['box_pred_class_id'] > 0 and box['box_score'] > 0.1):
                continue
            # if prob detection: plot only boxes from correct sampling instance.
            if 'sample_id' in box.keys() and int(box['sample_id']) != m - n_channels - 2:
                continue
            # if prob detection: plot reconstructed boxes only in corresponding line.
            if 'sample_id' not in box.keys() and m != n_channels + 1:
                continue

            plot_text = True
            score = np.max(box['box_score'])
            score_text = '{}|{:.0f}'.format(box['box_pred_class_id'], score*100)
            score_font_size = 7
            text_color = 'w'
            text_x = coords[1] + 10*(box['box_pred_class_id'] - 1)  # avoid overlap of scores in plot.
            text_y = coords[2] + 5
        elif box['box_type'] == 'gt':
            plot_text = True
            score_text = int(box['box_label'])
            score_font_size = 7
            text_color = 'r'
            text_x = coords[1]
            text_y = coords[0] - 1
        else:
            plot_text = False

        color_var = 'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type'
        color = cf.box_color_palette[box[color_var]]
        plt.plot([coords[1], coords[3]], [coords[0], coords[0]], color=color, linewidth=1, alpha=1)  # up
        plt.plot([coords[1], coords[3]], [coords[2], coords[2]], color=color, linewidth=1, alpha=1)  # down
        plt.plot([coords[1], coords[1]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1)  # left
        plt.plot([coords[3], coords[3]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1)  # right
        if plot_text:
            plt.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color)
|
|
186 |
|
|
|
187 |
|
|
|
188 |
|
|
|
189 |
class TrainingPlot_2Panel():
    """Holds the monitoring figures of a training run and redraws / saves them on request."""
    # todo remove since replaced by tensorboard?

    def __init__(self, cf):
        """
        Create cf.n_monitoring_figures figures, each with a single axes prepared for loss / metric curves.
        :param cf: config object. Attributes read here: plot_dir, fold, fold_dir, do_validation,
            assign_values_to_extra_figure, n_monitoring_figures, num_epochs.
        """
        self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold)
        self.exp_name = cf.fold_dir
        self.do_validation = cf.do_validation
        self.separate_values_dict = cf.assign_values_to_extra_figure

        self.figure_list = []
        for _ in range(cf.n_monitoring_figures):
            figure = plt.figure(figsize=(10, 6))
            self.figure_list.append(figure)
            # the axes is stored on the figure object so it can be retrieved when updating.
            axes = plt.subplot(111)
            figure.ax1 = axes
            axes.set_xlabel('epochs')
            axes.set_ylabel('loss / metrics')
            axes.set_xlim(0, cf.num_epochs)
            axes.grid()

        # only the first figure gets a fixed y-range.
        first_figure = self.figure_list[0]
        first_figure.ax1.set_ylim(0, 1.5)
        self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray']

    def update_and_save(self, metrics, epoch):
        """
        Draw the metrics up to the given epoch into every figure and save each figure to
        '<file_name>_<figure index>'.
        :param metrics: metrics structure passed on to detection_monitoring_plot.
        :param epoch: number of the current epoch.
        """
        for figure_ix, fig in enumerate(self.figure_list):
            detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix,
                                      self.separate_values_dict,
                                      self.do_validation)
            target_path = self.file_name + '_{}'.format(figure_ix)
            fig.savefig(target_path)
|
|
218 |
|
|
|
219 |
|
|
|
220 |
def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation):
    """
    Plot training (dashed) and, if do_validation is set, validation (solid) curves of the monitored values
    over epochs 1..epoch into the given axes.
    :param ax1: axes object to draw into.
    :param metrics: dict with key 'train' (and 'val' if do_validation). Each contains 'monitor_values', indexed by
        epoch, holding per epoch a list of dicts that map value names to numbers (averaged per epoch here), and
        optionally further keys holding one value per epoch, of which the entries from index 1 on are plotted.
    :param exp_name: title of the plot.
    :param color_palette: list of colors, one per plotted key. If there are more keys than colors, the colors
        are reused from the start of the list.
    :param epoch: number of the current epoch. Legend, title and axes position are only set at epoch 1.
    :param figure_ix: index of the figure. Figure 0 shows all values not listed in separate_values_dict, every
        other figure shows the keys in separate_values_dict[figure_ix].
    :param separate_values_dict: dict mapping figure index to the list of keys plotted in that extra figure.
    :param do_validation: whether validation curves are plotted as well.
    """
    # todo remove since replaced by tensorboard?
    monitor_values_keys = metrics['train']['monitor_values'][1][0].keys()
    separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix]
    if figure_ix == 0:
        plot_keys = [ii for ii in monitor_values_keys if ii not in separate_values]
        plot_keys += [k for k in metrics['train'].keys() if k != 'monitor_values']
    else:
        plot_keys = separate_values_dict[figure_ix]

    x = np.arange(1, epoch + 1)

    for kix, pk in enumerate(plot_keys):
        if pk in metrics['train'].keys():
            y_train = metrics['train'][pk][1:]
            if do_validation:
                y_val = metrics['val'][pk][1:]
        else:
            y_train = [np.mean([er[pk] for er in metrics['train']['monitor_values'][e]]) for e in x]
            if do_validation:
                y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x]

        # wrap around the palette, so that more keys than colors do not raise an IndexError.
        color = color_palette[kix % len(color_palette)]
        ax1.plot(x, y_train, label='train_{}'.format(pk), linestyle='--', color=color)
        if do_validation:
            ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color)

    if epoch == 1:
        # shrink the axes to 80% width to make room for the legend to the right of the plot.
        box = ax1.get_position()
        ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax1.set_title(exp_name)
|
|
252 |
|
|
|
253 |
|
|
|
254 |
def plot_prediction_hist(label_list: list, pred_list: list, type_list: list, outfile: str):
    """
    Plot a histogram (log-scaled counts) of the prediction scores for a specific class and save it to outfile.
    :param label_list: list of 1s and 0s specifying whether prediction is a true positive match (1) or a false positive (0).
        False negatives (missed ground truth objects) are artificially added predictions with score 0 and label 1.
    :param pred_list: list of prediction-scores.
    :param type_list: list of prediction-types for statistic-info in title. If None, no counts are added to the title.
    :param outfile: path the figure is saved to. Its last path component is used as start of the title.
    """
    preds = np.array(pred_list)
    labels = np.array(label_list)
    title = os.path.basename(outfile) + ' count:{}'.format(len(label_list))
    fig = plt.figure()
    try:
        plt.yscale('log')
        if 0 in labels:
            plt.hist(preds[labels == 0], alpha=0.3, color='g', range=(0, 1), bins=50, label='false pos.')
        if 1 in labels:
            plt.hist(preds[labels == 1], alpha=0.3, color='b', range=(0, 1), bins=50, label='true pos. (false neg. @ score=0)')

        if type_list is not None:
            fp_count = type_list.count('det_fp')
            fn_count = type_list.count('det_fn')
            tp_count = type_list.count('det_tp')
            pos_count = fn_count + tp_count
            title += ' tp:{} fp:{} fn:{} pos:{}'.format(tp_count, fp_count, fn_count, pos_count)

        plt.legend()
        plt.title(title)
        plt.xlabel('confidence score')
        plt.ylabel('log n')
        plt.savefig(outfile)
    finally:
        # close exactly this figure, also if plotting or saving raised, so figures do not pile up.
        plt.close(fig)
|
|
285 |
|
|
|
286 |
|
|
|
287 |
def plot_stat_curves(stats: list, outfile: str):
    """
    Plot the 'roc' and the 'prc' curves of all given stats entries into one figure per curve type and save
    the figures to '<outfile>_roc' and '<outfile>_prc'.
    :param stats: list of dicts with keys 'name', 'roc' and 'prc'. A curve is a pair of sequences plotted as
        (x values, y values). Entries whose curve is a float NaN are skipped.
    :param outfile: path prefix of the saved figures. Its last path component is used as start of the titles.
    """
    for c in ['roc', 'prc']:
        fig = plt.figure()
        try:
            for s in stats:
                # a float NaN marks a curve that is not available.
                if not (isinstance(s[c], float) and np.isnan(s[c])):
                    plt.plot(s[c][0], s[c][1], label=s['name'] + '_' + c)
            plt.title(os.path.basename(outfile) + '_' + c)
            plt.legend(loc=3 if c == 'prc' else 4)
            plt.xlabel('precision' if c == 'prc' else '1-spec.')
            plt.ylabel('recall')
            plt.savefig(outfile + '_' + c)
        finally:
            # close exactly this figure, also if plotting or saving raised, so figures do not pile up.
            plt.close(fig)