--- a +++ b/plotting.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Union + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import numpy as np +import os +from copy import deepcopy +import pickle + +def suppress_axes_lines(ax): + """ + :param ax: pyplot axes object + """ + ax.axes.get_xaxis().set_ticks([]) + ax.axes.get_yaxis().set_ticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + + return + +def plot_batch_prediction(batch: dict, results_dict: dict, cf, outfile: Union[str, None]=None, + suptitle: Union[str, None]=None): + """ + plot the input images, ground truth annotations, and output predictions of a batch. If 3D batch, plots a 2D projection + of one randomly sampled element (patient) in the batch. Since plotting all slices of patient volume blows up costs of + time and space, only a section containing a randomly sampled ground truth annotation is plotted. + :param batch: dict with keys: 'data' (input image), 'seg' (pixelwise annotations), 'pid' + :param results_dict: list over batch element. Each element is a list of boxes (prediction and ground truth), + where every box is a dictionary containing box_coords, box_score and box_type. + """ + print ("Outfile beginning"+str(outfile)) + if outfile is None: + outfile = os.path.join(cf.plot_dir, 'pred_example_{}.png'.format(cf.fold)) + + data = batch['data'] + segs = batch['seg'] + pids = batch['pid'] + # for 3D, repeat pid over batch elements. + if len(set(pids)) == 1: + pids = [pids] * data.shape[0] + + seg_preds = results_dict['seg_preds'] + roi_results = deepcopy(results_dict['boxes']) + + # Randomly sampled one patient of batch and project data into 2D slices for plotting. + if cf.dim == 3: + patient_ix = np.random.choice(data.shape[0]) + data = np.transpose(data[patient_ix], axes=(3, 0, 1, 2)) + + # select interesting foreground section to plot. + gt_boxes = [box['box_coords'] for box in roi_results[patient_ix] if box['box_type'] == 'gt'] + if len(gt_boxes) > 0: + z_cuts = [np.max((int(gt_boxes[0][4]) - 5, 0)), np.min((int(gt_boxes[0][5]) + 5, data.shape[0]))] + else: + z_cuts = [data.shape[0]//2 - 5, int(data.shape[0]//2 + np.min([10, data.shape[0]//2]))] + p_roi_results = roi_results[patient_ix] + roi_results = [[] for _ in range(data.shape[0])] + + # iterate over cubes and spread across slices. + for box in p_roi_results: + b = box['box_coords'] + # dismiss negative anchor slices. + slices = np.round(np.unique(np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0]-1))) + for s in slices: + roi_results[int(s)].append(box) + roi_results[int(s)][-1]['box_coords'] = b[:4] + + roi_results = roi_results[z_cuts[0]: z_cuts[1]] + data = data[z_cuts[0]: z_cuts[1]] + segs = np.transpose(segs[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]] + seg_preds = np.transpose(seg_preds[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]] + pids = [pids[patient_ix]] * data.shape[0] + + try: + # all dimensions except for the 'channel-dimension' are required to match + for i in [0, 2, 3]: + assert data.shape[i] == segs.shape[i] == seg_preds.shape[i] + except: + raise Warning('Shapes of arrays to plot not in agreement!' + 'Shapes {} vs. {} vs {}'.format(data.shape, segs.shape, seg_preds.shape)) + + + show_arrays = np.concatenate([data, segs, seg_preds, data[:, 0][:, None]], axis=1).astype(float) + approx_figshape = (4 * show_arrays.shape[0], 4 * show_arrays.shape[1]) + fig = plt.figure(figsize=approx_figshape) + gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0]) + gs.update(wspace=0.1, hspace=0.1) + for b in range(show_arrays.shape[0]): + for m in range(show_arrays.shape[1]): + + ax = plt.subplot(gs[m, b]) + suppress_axes_lines(ax) + if m < show_arrays.shape[1]: + arr = show_arrays[b, m] + + if m < data.shape[1] or m == show_arrays.shape[1] - 1: + if b == 0: + ax.set_ylabel("Input" + (" + GT & Pred Box" if m == show_arrays.shape[1] - 1 else "")) + cmap = 'gray' + vmin = None + vmax = None + else: + cmap = None + vmin = 0 + vmax = cf.num_seg_classes - 1 + + if m == 0: + plt.title('{}'.format(pids[b][:10]), fontsize=20) + + plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax) + if m >= (data.shape[1]): + if b == 0: + if m == data.shape[1]: + ax.set_ylabel("GT Box & Seg") + if m == data.shape[1]+1: + ax.set_ylabel("GT Box + Pred Seg & Box") + for box in roi_results[b]: + if box['box_type'] != 'patient_tn_box': # don't plot true negative dummy boxes. + coords = box['box_coords'] + if box['box_type'] == 'det': + # dont plot background preds or low confidence boxes. + if box['box_pred_class_id'] > 0 and box['box_score'] > 0.1: + plot_text = True + score = np.max(box['box_score']) + score_text = '{}|{:.0f}'.format(box['box_pred_class_id'], score*100) + # if prob detection: plot only boxes from correct sampling instance. + if 'sample_id' in box.keys() and int(box['sample_id']) != m - data.shape[1] - 2: + continue + # if prob detection: plot reconstructed boxes only in corresponding line. + if not 'sample_id' in box.keys() and m != data.shape[1] + 1: + continue + + score_font_size = 7 + text_color = 'w' + text_x = coords[1] + 10*(box['box_pred_class_id'] -1) #avoid overlap of scores in plot. + text_y = coords[2] + 5 + else: + continue + elif box['box_type'] == 'gt': + plot_text = True + score_text = int(box['box_label']) + score_font_size = 7 + text_color = 'r' + text_x = coords[1] + text_y = coords[0] - 1 + else: + plot_text = False + + color_var = 'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type' + color = cf.box_color_palette[box[color_var]] + plt.plot([coords[1], coords[3]], [coords[0], coords[0]], color=color, linewidth=1, alpha=1) # up + plt.plot([coords[1], coords[3]], [coords[2], coords[2]], color=color, linewidth=1, alpha=1) # down + plt.plot([coords[1], coords[1]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # left + plt.plot([coords[3], coords[3]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # right + if plot_text: + plt.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color) + + if suptitle is not None: + plt.suptitle(suptitle, fontsize=22) + + print ("Outfile end"+str(outfile)) + try: + plt.savefig(outfile) + except: + raise Warning('failed to save plot.') + plt.close(fig) + + + +class TrainingPlot_2Panel(): + # todo remove since replaced by tensorboard? + + def __init__(self, cf): + + self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold) + self.exp_name = cf.fold_dir + self.do_validation = cf.do_validation + self.separate_values_dict = cf.assign_values_to_extra_figure + self.figure_list = [] + for n in range(cf.n_monitoring_figures): + self.figure_list.append(plt.figure(figsize=(10, 6))) + self.figure_list[-1].ax1 = plt.subplot(111) + self.figure_list[-1].ax1.set_xlabel('epochs') + self.figure_list[-1].ax1.set_ylabel('loss / metrics') + self.figure_list[-1].ax1.set_xlim(0, cf.num_epochs) + self.figure_list[-1].ax1.grid() + + self.figure_list[0].ax1.set_ylim(0, 1.5) + self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray'] + + def update_and_save(self, metrics, epoch): + + for figure_ix in range(len(self.figure_list)): + fig = self.figure_list[figure_ix] + detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix, + self.separate_values_dict, + self.do_validation) + fig.savefig(self.file_name + '_{}'.format(figure_ix)) + + +def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation): + # todo remove since replaced by tensorboard? + monitor_values_keys = metrics['train']['monitor_values'][1][0].keys() + separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix] + if figure_ix == 0: + plot_keys = [ii for ii in monitor_values_keys if ii not in separate_values] + plot_keys += [k for k in metrics['train'].keys() if k != 'monitor_values'] + else: + plot_keys = separate_values_dict[figure_ix] + + + x = np.arange(1, epoch + 1) + + for kix, pk in enumerate(plot_keys): + if pk in metrics['train'].keys(): + y_train = metrics['train'][pk][1:] + if do_validation: + y_val = metrics['val'][pk][1:] + else: + y_train = [np.mean([er[pk] for er in metrics['train']['monitor_values'][e]]) for e in x] + if do_validation: + y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x] + + ax1.plot(x, y_train, label='train_{}'.format(pk), linestyle='--', color=color_palette[kix]) + if do_validation: + ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix]) + + if epoch == 1: + box = ax1.get_position() + ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + ax1.set_title(exp_name) + + +def plot_prediction_hist(label_list: list, pred_list: list, type_list: list, outfile: str): + """ + plot histogram of predictions for a specific class. + :param label_list: list of 1s and 0s specifying whether prediction is a true positive match (1) or a false positive (0). + False negatives (missed ground truth objects) are artificially added predictions with score 0 and label 1. + :param pred_list: list of prediction-scores. + :param type_list: list of prediction-types for stastic-info in title. + """ + preds = np.array(pred_list) + labels = np.array(label_list) + title = outfile.split('/')[-1] + ' count:{}'.format(len(label_list)) + plt.figure() + plt.yscale('log') + if 0 in labels: + plt.hist(preds[labels == 0], alpha=0.3, color='g', range=(0, 1), bins=50, label='false pos.') + if 1 in labels: + plt.hist(preds[labels == 1], alpha=0.3, color='b', range=(0, 1), bins=50, label='true pos. (false neg. @ score=0)') + + if type_list is not None: + fp_count = type_list.count('det_fp') + fn_count = type_list.count('det_fn') + tp_count = type_list.count('det_tp') + pos_count = fn_count + tp_count + title += ' tp:{} fp:{} fn:{} pos:{}'. format(tp_count, fp_count, fn_count, pos_count) + + plt.legend() + plt.title(title) + plt.xlabel('confidence score') + plt.ylabel('log n') + plt.savefig(outfile) + plt.close() + + +def plot_stat_curves(stats: list, outfile: str): + + for c in ['roc', 'prc']: + plt.figure() + for s in stats: + if not (isinstance(s[c], float) and np.isnan(s[c])): + plt.plot(s[c][0], s[c][1], label=s['name'] + '_' + c) + plt.title(outfile.split('/')[-1] + '_' + c) + plt.legend(loc=3 if c == 'prc' else 4) + plt.xlabel('precision' if c == 'prc' else '1-spec.') + plt.ylabel('recall') + plt.savefig(outfile + '_' + c) + plt.close()