Diff of /plotting.py [000000] .. [bb7f56]

Switch to unified view

a b/plotting.py
1
#!/usr/bin/env python
2
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
# ==============================================================================
16
from typing import Union
17
18
import matplotlib
19
matplotlib.use('Agg')
20
import matplotlib.pyplot as plt
21
import matplotlib.gridspec as gridspec
22
import numpy as np
23
import os
24
from copy import deepcopy
25
import pickle
26
27
def suppress_axes_lines(ax):
28
    """
29
    :param ax: pyplot axes object
30
    """
31
    ax.axes.get_xaxis().set_ticks([])
32
    ax.axes.get_yaxis().set_ticks([])
33
    ax.spines['top'].set_visible(False)
34
    ax.spines['right'].set_visible(False)
35
    ax.spines['bottom'].set_visible(False)
36
    ax.spines['left'].set_visible(False)
37
38
    return
39
40
def plot_batch_prediction(batch: dict, results_dict: dict, cf, outfile: Union[str, None]=None,
41
                          suptitle: Union[str, None]=None):
42
    """
43
    plot the input images, ground truth annotations, and output predictions of a batch. If 3D batch, plots a 2D projection
44
    of one randomly sampled element (patient) in the batch. Since plotting all slices of patient volume blows up costs of
45
    time and space, only a section containing a randomly sampled ground truth annotation is plotted.
46
    :param batch: dict with keys: 'data' (input image), 'seg' (pixelwise annotations), 'pid'
47
    :param results_dict: list over batch element. Each element is a list of boxes (prediction and ground truth),
48
    where every box is a dictionary containing box_coords, box_score and box_type.
49
    """
50
    print ("Outfile beginning"+str(outfile))
51
    if outfile is None:
52
        outfile = os.path.join(cf.plot_dir, 'pred_example_{}.png'.format(cf.fold))
53
54
    data = batch['data']
55
    segs = batch['seg']
56
    pids = batch['pid']
57
    # for 3D, repeat pid over batch elements.
58
    if len(set(pids)) == 1:
59
        pids = [pids] * data.shape[0]
60
61
    seg_preds = results_dict['seg_preds']
62
    roi_results = deepcopy(results_dict['boxes'])
63
64
    # Randomly sampled one patient of batch and project data into 2D slices for plotting.
65
    if cf.dim == 3:
66
        patient_ix = np.random.choice(data.shape[0])
67
        data = np.transpose(data[patient_ix], axes=(3, 0, 1, 2))
68
69
        # select interesting foreground section to plot.
70
        gt_boxes = [box['box_coords'] for box in roi_results[patient_ix] if box['box_type'] == 'gt']
71
        if len(gt_boxes) > 0:
72
            z_cuts = [np.max((int(gt_boxes[0][4]) - 5, 0)), np.min((int(gt_boxes[0][5]) + 5, data.shape[0]))]
73
        else:
74
            z_cuts = [data.shape[0]//2 - 5, int(data.shape[0]//2 + np.min([10, data.shape[0]//2]))]
75
        p_roi_results = roi_results[patient_ix]
76
        roi_results = [[] for _ in range(data.shape[0])]
77
78
        # iterate over cubes and spread across slices.
79
        for box in p_roi_results:
80
            b = box['box_coords']
81
            # dismiss negative anchor slices.
82
            slices = np.round(np.unique(np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0]-1)))
83
            for s in slices:
84
                roi_results[int(s)].append(box)
85
                roi_results[int(s)][-1]['box_coords'] = b[:4]
86
87
        roi_results = roi_results[z_cuts[0]: z_cuts[1]]
88
        data = data[z_cuts[0]: z_cuts[1]]
89
        segs = np.transpose(segs[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
90
        seg_preds = np.transpose(seg_preds[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
91
        pids = [pids[patient_ix]] * data.shape[0]
92
93
    try:
94
        # all dimensions except for the 'channel-dimension' are required to match
95
        for i in [0, 2, 3]:
96
            assert data.shape[i] == segs.shape[i] == seg_preds.shape[i]
97
    except:
98
        raise Warning('Shapes of arrays to plot not in agreement!'
99
                      'Shapes {} vs. {} vs {}'.format(data.shape, segs.shape, seg_preds.shape))
100
101
102
    show_arrays = np.concatenate([data, segs, seg_preds, data[:, 0][:, None]], axis=1).astype(float)
103
    approx_figshape = (4 * show_arrays.shape[0], 4 * show_arrays.shape[1])
104
    fig = plt.figure(figsize=approx_figshape)
105
    gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0])
106
    gs.update(wspace=0.1, hspace=0.1)
107
    for b in range(show_arrays.shape[0]):
108
        for m in range(show_arrays.shape[1]):
109
110
            ax = plt.subplot(gs[m, b])
111
            suppress_axes_lines(ax)
112
            if m < show_arrays.shape[1]:
113
                arr = show_arrays[b, m]
114
115
            if m < data.shape[1] or m == show_arrays.shape[1] - 1:
116
                if b == 0:
117
                    ax.set_ylabel("Input" + (" + GT & Pred Box" if m == show_arrays.shape[1] - 1 else ""))
118
                cmap = 'gray'
119
                vmin = None
120
                vmax = None
121
            else:
122
                cmap = None
123
                vmin = 0
124
                vmax = cf.num_seg_classes - 1
125
126
            if m == 0:
127
                plt.title('{}'.format(pids[b][:10]), fontsize=20)
128
129
            plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
130
            if m >= (data.shape[1]):
131
                if b == 0:
132
                    if m == data.shape[1]:
133
                        ax.set_ylabel("GT Box & Seg")
134
                    if m == data.shape[1]+1:
135
                        ax.set_ylabel("GT Box + Pred Seg & Box")
136
                for box in roi_results[b]:
137
                    if box['box_type'] != 'patient_tn_box': # don't plot true negative dummy boxes.
138
                        coords = box['box_coords']
139
                        if box['box_type'] == 'det':
140
                            # dont plot background preds or low confidence boxes.
141
                            if box['box_pred_class_id'] > 0 and box['box_score'] > 0.1:
142
                                plot_text = True
143
                                score = np.max(box['box_score'])
144
                                score_text = '{}|{:.0f}'.format(box['box_pred_class_id'], score*100)
145
                                # if prob detection: plot only boxes from correct sampling instance.
146
                                if 'sample_id' in box.keys() and int(box['sample_id']) != m - data.shape[1] - 2:
147
                                        continue
148
                                # if prob detection: plot reconstructed boxes only in corresponding line.
149
                                if not 'sample_id' in box.keys() and  m != data.shape[1] + 1:
150
                                    continue
151
152
                                score_font_size = 7
153
                                text_color = 'w'
154
                                text_x = coords[1] + 10*(box['box_pred_class_id'] -1) #avoid overlap of scores in plot.
155
                                text_y = coords[2] + 5
156
                            else:
157
                                continue
158
                        elif box['box_type'] == 'gt':
159
                            plot_text = True
160
                            score_text = int(box['box_label'])
161
                            score_font_size = 7
162
                            text_color = 'r'
163
                            text_x = coords[1]
164
                            text_y = coords[0] - 1
165
                        else:
166
                            plot_text = False
167
168
                        color_var = 'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type'
169
                        color = cf.box_color_palette[box[color_var]]
170
                        plt.plot([coords[1], coords[3]], [coords[0], coords[0]], color=color, linewidth=1, alpha=1) # up
171
                        plt.plot([coords[1], coords[3]], [coords[2], coords[2]], color=color, linewidth=1, alpha=1) # down
172
                        plt.plot([coords[1], coords[1]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # left
173
                        plt.plot([coords[3], coords[3]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # right
174
                        if plot_text:
175
                            plt.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color)
176
177
    if suptitle is not None:
178
        plt.suptitle(suptitle, fontsize=22)
179
180
    print ("Outfile end"+str(outfile))
181
    try:
182
        plt.savefig(outfile)
183
    except:
184
        raise Warning('failed to save plot.')
185
    plt.close(fig)
186
187
188
189
class TrainingPlot_2Panel():
190
    # todo remove since replaced by tensorboard?
191
192
    def __init__(self, cf):
193
194
        self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold)
195
        self.exp_name = cf.fold_dir
196
        self.do_validation = cf.do_validation
197
        self.separate_values_dict = cf.assign_values_to_extra_figure
198
        self.figure_list = []
199
        for n in range(cf.n_monitoring_figures):
200
            self.figure_list.append(plt.figure(figsize=(10, 6)))
201
            self.figure_list[-1].ax1 = plt.subplot(111)
202
            self.figure_list[-1].ax1.set_xlabel('epochs')
203
            self.figure_list[-1].ax1.set_ylabel('loss / metrics')
204
            self.figure_list[-1].ax1.set_xlim(0, cf.num_epochs)
205
            self.figure_list[-1].ax1.grid()
206
207
        self.figure_list[0].ax1.set_ylim(0, 1.5)
208
        self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray']
209
210
    def update_and_save(self, metrics, epoch):
211
212
        for figure_ix in range(len(self.figure_list)):
213
            fig = self.figure_list[figure_ix]
214
            detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix,
215
                                      self.separate_values_dict,
216
                                      self.do_validation)
217
            fig.savefig(self.file_name + '_{}'.format(figure_ix))
218
219
220
def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation):
221
    # todo remove since replaced by tensorboard?
222
    monitor_values_keys = metrics['train']['monitor_values'][1][0].keys()
223
    separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix]
224
    if figure_ix == 0:
225
        plot_keys = [ii for ii in monitor_values_keys if ii not in separate_values]
226
        plot_keys += [k for k in metrics['train'].keys() if k != 'monitor_values']
227
    else:
228
        plot_keys = separate_values_dict[figure_ix]
229
230
231
    x = np.arange(1, epoch + 1)
232
233
    for kix, pk in enumerate(plot_keys):
234
        if pk in metrics['train'].keys():
235
            y_train = metrics['train'][pk][1:]
236
            if do_validation:
237
                y_val = metrics['val'][pk][1:]
238
        else:
239
            y_train = [np.mean([er[pk] for er in metrics['train']['monitor_values'][e]]) for e in x]
240
            if do_validation:
241
                y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x]
242
243
        ax1.plot(x, y_train, label='train_{}'.format(pk), linestyle='--', color=color_palette[kix])
244
        if do_validation:
245
            ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix])
246
247
    if epoch == 1:
248
        box = ax1.get_position()
249
        ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height])
250
        ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
251
        ax1.set_title(exp_name)
252
253
254
def plot_prediction_hist(label_list: list, pred_list: list, type_list: list, outfile: str):
255
    """
256
    plot histogram of predictions for a specific class.
257
    :param label_list: list of 1s and 0s specifying whether prediction is a true positive match (1) or a false positive (0).
258
    False negatives (missed ground truth objects) are artificially added predictions with score 0 and label 1.
259
    :param pred_list: list of prediction-scores.
260
    :param type_list: list of prediction-types for stastic-info in title.
261
    """
262
    preds = np.array(pred_list)
263
    labels = np.array(label_list)
264
    title = outfile.split('/')[-1] + ' count:{}'.format(len(label_list))
265
    plt.figure()
266
    plt.yscale('log')
267
    if 0 in labels:
268
        plt.hist(preds[labels == 0], alpha=0.3, color='g', range=(0, 1), bins=50, label='false pos.')
269
    if 1 in labels:
270
        plt.hist(preds[labels == 1], alpha=0.3, color='b', range=(0, 1), bins=50, label='true pos. (false neg. @ score=0)')
271
272
    if type_list is not None:
273
        fp_count = type_list.count('det_fp')
274
        fn_count = type_list.count('det_fn')
275
        tp_count = type_list.count('det_tp')
276
        pos_count = fn_count + tp_count
277
        title += ' tp:{} fp:{} fn:{} pos:{}'. format(tp_count, fp_count, fn_count, pos_count)
278
279
    plt.legend()
280
    plt.title(title)
281
    plt.xlabel('confidence score')
282
    plt.ylabel('log n')
283
    plt.savefig(outfile)
284
    plt.close()
285
286
287
def plot_stat_curves(stats: list, outfile: str):
288
289
    for c in ['roc', 'prc']:
290
        plt.figure()
291
        for s in stats:
292
            if not (isinstance(s[c], float) and np.isnan(s[c])):
293
                plt.plot(s[c][0], s[c][1], label=s['name'] + '_' + c)
294
        plt.title(outfile.split('/')[-1] + '_' + c)
295
        plt.legend(loc=3 if c == 'prc' else 4)
296
        plt.xlabel('precision' if c == 'prc' else '1-spec.')
297
        plt.ylabel('recall')
298
        plt.savefig(outfile + '_' + c)
299
        plt.close()