--- a
+++ b/plotting.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Union
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import numpy as np
+import os
+from copy import deepcopy
+import pickle
+
+def suppress_axes_lines(ax):
+    """
+    :param ax: pyplot axes object
+    """
+    ax.axes.get_xaxis().set_ticks([])
+    ax.axes.get_yaxis().set_ticks([])
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.spines['bottom'].set_visible(False)
+    ax.spines['left'].set_visible(False)
+
+    return
+
+def plot_batch_prediction(batch: dict, results_dict: dict, cf, outfile: Union[str, None]=None,
+                          suptitle: Union[str, None]=None):
+    """
+    plot the input images, ground truth annotations, and output predictions of a batch. If 3D batch, plots a 2D projection
+    of one randomly sampled element (patient) in the batch. Since plotting all slices of patient volume blows up costs of
+    time and space, only a section containing a randomly sampled ground truth annotation is plotted.
+    :param batch: dict with keys: 'data' (input image), 'seg' (pixelwise annotations), 'pid'
+    :param results_dict: list over batch element. Each element is a list of boxes (prediction and ground truth),
+    where every box is a dictionary containing box_coords, box_score and box_type.
+    """
+    print ("Outfile beginning"+str(outfile))
+    if outfile is None:
+        outfile = os.path.join(cf.plot_dir, 'pred_example_{}.png'.format(cf.fold))
+
+    data = batch['data']
+    segs = batch['seg']
+    pids = batch['pid']
+    # for 3D, repeat pid over batch elements.
+    if len(set(pids)) == 1:
+        pids = [pids] * data.shape[0]
+
+    seg_preds = results_dict['seg_preds']
+    roi_results = deepcopy(results_dict['boxes'])
+
+    # Randomly sampled one patient of batch and project data into 2D slices for plotting.
+    if cf.dim == 3:
+        patient_ix = np.random.choice(data.shape[0])
+        data = np.transpose(data[patient_ix], axes=(3, 0, 1, 2))
+
+        # select interesting foreground section to plot.
+        gt_boxes = [box['box_coords'] for box in roi_results[patient_ix] if box['box_type'] == 'gt']
+        if len(gt_boxes) > 0:
+            z_cuts = [np.max((int(gt_boxes[0][4]) - 5, 0)), np.min((int(gt_boxes[0][5]) + 5, data.shape[0]))]
+        else:
+            z_cuts = [data.shape[0]//2 - 5, int(data.shape[0]//2 + np.min([10, data.shape[0]//2]))]
+        p_roi_results = roi_results[patient_ix]
+        roi_results = [[] for _ in range(data.shape[0])]
+
+        # iterate over cubes and spread across slices.
+        for box in p_roi_results:
+            b = box['box_coords']
+            # dismiss negative anchor slices.
+            slices = np.round(np.unique(np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0]-1)))
+            for s in slices:
+                roi_results[int(s)].append(box)
+                roi_results[int(s)][-1]['box_coords'] = b[:4]
+
+        roi_results = roi_results[z_cuts[0]: z_cuts[1]]
+        data = data[z_cuts[0]: z_cuts[1]]
+        segs = np.transpose(segs[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
+        seg_preds = np.transpose(seg_preds[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]]
+        pids = [pids[patient_ix]] * data.shape[0]
+
+    try:
+        # all dimensions except for the 'channel-dimension' are required to match
+        for i in [0, 2, 3]:
+            assert data.shape[i] == segs.shape[i] == seg_preds.shape[i]
+    except:
+        raise Warning('Shapes of arrays to plot not in agreement!'
+                      'Shapes {} vs. {} vs {}'.format(data.shape, segs.shape, seg_preds.shape))
+
+
+    show_arrays = np.concatenate([data, segs, seg_preds, data[:, 0][:, None]], axis=1).astype(float)
+    approx_figshape = (4 * show_arrays.shape[0], 4 * show_arrays.shape[1])
+    fig = plt.figure(figsize=approx_figshape)
+    gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0])
+    gs.update(wspace=0.1, hspace=0.1)
+    for b in range(show_arrays.shape[0]):
+        for m in range(show_arrays.shape[1]):
+
+            ax = plt.subplot(gs[m, b])
+            suppress_axes_lines(ax)
+            if m < show_arrays.shape[1]:
+                arr = show_arrays[b, m]
+
+            if m < data.shape[1] or m == show_arrays.shape[1] - 1:
+                if b == 0:
+                    ax.set_ylabel("Input" + (" + GT & Pred Box" if m == show_arrays.shape[1] - 1 else ""))
+                cmap = 'gray'
+                vmin = None
+                vmax = None
+            else:
+                cmap = None
+                vmin = 0
+                vmax = cf.num_seg_classes - 1
+
+            if m == 0:
+                plt.title('{}'.format(pids[b][:10]), fontsize=20)
+
+            plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
+            if m >= (data.shape[1]):
+                if b == 0:
+                    if m == data.shape[1]:
+                        ax.set_ylabel("GT Box & Seg")
+                    if m == data.shape[1]+1:
+                        ax.set_ylabel("GT Box + Pred Seg & Box")
+                for box in roi_results[b]:
+                    if box['box_type'] != 'patient_tn_box': # don't plot true negative dummy boxes.
+                        coords = box['box_coords']
+                        if box['box_type'] == 'det':
+                            # dont plot background preds or low confidence boxes.
+                            if box['box_pred_class_id'] > 0 and box['box_score'] > 0.1:
+                                plot_text = True
+                                score = np.max(box['box_score'])
+                                score_text = '{}|{:.0f}'.format(box['box_pred_class_id'], score*100)
+                                # if prob detection: plot only boxes from correct sampling instance.
+                                if 'sample_id' in box.keys() and int(box['sample_id']) != m - data.shape[1] - 2:
+                                        continue
+                                # if prob detection: plot reconstructed boxes only in corresponding line.
+                                if not 'sample_id' in box.keys() and  m != data.shape[1] + 1:
+                                    continue
+
+                                score_font_size = 7
+                                text_color = 'w'
+                                text_x = coords[1] + 10*(box['box_pred_class_id'] -1) #avoid overlap of scores in plot.
+                                text_y = coords[2] + 5
+                            else:
+                                continue
+                        elif box['box_type'] == 'gt':
+                            plot_text = True
+                            score_text = int(box['box_label'])
+                            score_font_size = 7
+                            text_color = 'r'
+                            text_x = coords[1]
+                            text_y = coords[0] - 1
+                        else:
+                            plot_text = False
+
+                        color_var = 'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type'
+                        color = cf.box_color_palette[box[color_var]]
+                        plt.plot([coords[1], coords[3]], [coords[0], coords[0]], color=color, linewidth=1, alpha=1) # up
+                        plt.plot([coords[1], coords[3]], [coords[2], coords[2]], color=color, linewidth=1, alpha=1) # down
+                        plt.plot([coords[1], coords[1]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # left
+                        plt.plot([coords[3], coords[3]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # right
+                        if plot_text:
+                            plt.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color)
+
+    if suptitle is not None:
+        plt.suptitle(suptitle, fontsize=22)
+
+    print ("Outfile end"+str(outfile))
+    try:
+        plt.savefig(outfile)
+    except:
+        raise Warning('failed to save plot.')
+    plt.close(fig)
+
+
+
+class TrainingPlot_2Panel():
+    # todo remove since replaced by tensorboard?
+
+    def __init__(self, cf):
+
+        self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold)
+        self.exp_name = cf.fold_dir
+        self.do_validation = cf.do_validation
+        self.separate_values_dict = cf.assign_values_to_extra_figure
+        self.figure_list = []
+        for n in range(cf.n_monitoring_figures):
+            self.figure_list.append(plt.figure(figsize=(10, 6)))
+            self.figure_list[-1].ax1 = plt.subplot(111)
+            self.figure_list[-1].ax1.set_xlabel('epochs')
+            self.figure_list[-1].ax1.set_ylabel('loss / metrics')
+            self.figure_list[-1].ax1.set_xlim(0, cf.num_epochs)
+            self.figure_list[-1].ax1.grid()
+
+        self.figure_list[0].ax1.set_ylim(0, 1.5)
+        self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray']
+
+    def update_and_save(self, metrics, epoch):
+
+        for figure_ix in range(len(self.figure_list)):
+            fig = self.figure_list[figure_ix]
+            detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix,
+                                      self.separate_values_dict,
+                                      self.do_validation)
+            fig.savefig(self.file_name + '_{}'.format(figure_ix))
+
+
+def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation):
+    # todo remove since replaced by tensorboard?
+    monitor_values_keys = metrics['train']['monitor_values'][1][0].keys()
+    separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix]
+    if figure_ix == 0:
+        plot_keys = [ii for ii in monitor_values_keys if ii not in separate_values]
+        plot_keys += [k for k in metrics['train'].keys() if k != 'monitor_values']
+    else:
+        plot_keys = separate_values_dict[figure_ix]
+
+
+    x = np.arange(1, epoch + 1)
+
+    for kix, pk in enumerate(plot_keys):
+        if pk in metrics['train'].keys():
+            y_train = metrics['train'][pk][1:]
+            if do_validation:
+                y_val = metrics['val'][pk][1:]
+        else:
+            y_train = [np.mean([er[pk] for er in metrics['train']['monitor_values'][e]]) for e in x]
+            if do_validation:
+                y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x]
+
+        ax1.plot(x, y_train, label='train_{}'.format(pk), linestyle='--', color=color_palette[kix])
+        if do_validation:
+            ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix])
+
+    if epoch == 1:
+        box = ax1.get_position()
+        ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height])
+        ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))
+        ax1.set_title(exp_name)
+
+
+def plot_prediction_hist(label_list: list, pred_list: list, type_list: list, outfile: str):
+    """
+    plot histogram of predictions for a specific class.
+    :param label_list: list of 1s and 0s specifying whether prediction is a true positive match (1) or a false positive (0).
+    False negatives (missed ground truth objects) are artificially added predictions with score 0 and label 1.
+    :param pred_list: list of prediction-scores.
+    :param type_list: list of prediction-types for stastic-info in title.
+    """
+    preds = np.array(pred_list)
+    labels = np.array(label_list)
+    title = outfile.split('/')[-1] + ' count:{}'.format(len(label_list))
+    plt.figure()
+    plt.yscale('log')
+    if 0 in labels:
+        plt.hist(preds[labels == 0], alpha=0.3, color='g', range=(0, 1), bins=50, label='false pos.')
+    if 1 in labels:
+        plt.hist(preds[labels == 1], alpha=0.3, color='b', range=(0, 1), bins=50, label='true pos. (false neg. @ score=0)')
+
+    if type_list is not None:
+        fp_count = type_list.count('det_fp')
+        fn_count = type_list.count('det_fn')
+        tp_count = type_list.count('det_tp')
+        pos_count = fn_count + tp_count
+        title += ' tp:{} fp:{} fn:{} pos:{}'. format(tp_count, fp_count, fn_count, pos_count)
+
+    plt.legend()
+    plt.title(title)
+    plt.xlabel('confidence score')
+    plt.ylabel('log n')
+    plt.savefig(outfile)
+    plt.close()
+
+
+def plot_stat_curves(stats: list, outfile: str):
+
+    for c in ['roc', 'prc']:
+        plt.figure()
+        for s in stats:
+            if not (isinstance(s[c], float) and np.isnan(s[c])):
+                plt.plot(s[c][0], s[c][1], label=s['name'] + '_' + c)
+        plt.title(outfile.split('/')[-1] + '_' + c)
+        plt.legend(loc=3 if c == 'prc' else 4)
+        plt.xlabel('precision' if c == 'prc' else '1-spec.')
+        plt.ylabel('recall')
+        plt.savefig(outfile + '_' + c)
+        plt.close()