--- a
+++ b/scripts/cv_analysis/evaluate.py
@@ -0,0 +1,341 @@
+#==============================================================================#
+#  Author:       Dominik Müller                                                #
+#  Copyright:    2020 IT-Infrastructure for Translational Medical Research,    #
+#                University of Augsburg                                        #
+#                                                                              #
+#  This program is free software: you can redistribute it and/or modify        #
+#  it under the terms of the GNU General Public License as published by        #
+#  the Free Software Foundation, either version 3 of the License, or           #
+#  (at your option) any later version.                                         #
+#                                                                              #
+#  This program is distributed in the hope that it will be useful,             #
+#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+#  GNU General Public License for more details.                                #
+#                                                                              #
+#  You should have received a copy of the GNU General Public License           #
+#  along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#==============================================================================#
+#-----------------------------------------------------#
+#                   Library imports                   #
+#-----------------------------------------------------#
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+import numpy as np
+import pandas as pd
+import os
+from tqdm import tqdm
+from miscnn.data_loading.interfaces import NIFTI_interface
+from miscnn import Data_IO
+from miscnn.evaluation.cross_validation import load_disk2fold
+from plotnine import *
+import argparse
+
+#-----------------------------------------------------#
+#                      Argparser                      #
+#-----------------------------------------------------#
+parser = argparse.ArgumentParser(description="Automated COVID-19 Segmentation")
+parser.add_argument("-e", "--evaluation", help="Path to evaluation directory",
+                    required=True, type=str, dest="eval")
+parser.add_argument("-p", "--predictions", help="Path to predictions directory",
+                    required=True, type=str, dest="pred")
+args = parser.parse_args()
+eval_path = args.eval
+pred_path = args.pred
+
+#-----------------------------------------------------#
+#                    Visualization                    #
+#-----------------------------------------------------#
+def visualize_evaluation(case_id, vol, truth, pred, eva_path):
+    # Squeeze image arrays to remove the channel axis
+    vol = np.squeeze(vol, axis=-1)
+    truth = np.squeeze(truth, axis=-1)
+    pred = np.squeeze(pred, axis=-1)
+    # Color volumes according to truth and pred segmentation
+    vol_raw = overlay_segmentation(vol, np.zeros(vol.shape))
+    vol_truth = overlay_segmentation(vol, truth)
+    vol_pred = overlay_segmentation(vol, pred)
+    # Create a figure and three axes objects from matplotlib
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
+    # Initialize the three subplots (axes) with an empty image
+    data = np.zeros(vol.shape[0:2])
+    ax1.set_title("CT Scan")
+    ax2.set_title("Ground Truth")
+    ax3.set_title("Prediction")
+    img1 = ax1.imshow(data)
+    img2 = ax2.imshow(data)
+    img3 = ax3.imshow(data)
+    # Update function for all three images to show the slice for the current frame
+    def update(i):
+        plt.suptitle("Case ID: " + str(case_id) + " - " + "Slice: " + str(i))
+        img1.set_data(vol_raw[:,:,i])
+        img2.set_data(vol_truth[:,:,i])
+        img3.set_data(vol_pred[:,:,i])
+        return [img1, img2, img3]
+    # Compute the animation (gif)
+    ani = animation.FuncAnimation(fig, update, frames=truth.shape[2],
+                                  interval=10, repeat_delay=0, blit=False)
+    # Set up the output path for the gif
+    if not os.path.exists(eva_path):
+        os.mkdir(eva_path)
+    file_name = "visualization." + str(case_id).zfill(5) + ".gif"
+    out_path = os.path.join(eva_path, file_name)
+    # Save the animation (gif)
+    ani.save(out_path, writer='imagemagick', fps=20, dpi=150)
+    # Close the matplotlib figure
+    plt.close()
+
+# Based on: https://github.com/neheller/kits19/blob/master/starter_code/visualize.py
+def overlay_segmentation(vol, seg):
+    # Clip intensities to -1250 and +250
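+    # (CT intensity values are in Hounsfield units)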
+    vol = np.clip(vol, -1250, 250)
+    # Scale volume to greyscale range
+    vol_scaled = (vol - np.min(vol)) / (np.max(vol) - np.min(vol))
+    vol_greyscale = np.around(vol_scaled * 255, decimals=0).astype(np.uint8)
+    # Convert volume to RGB
+    vol_rgb = np.stack([vol_greyscale, vol_greyscale, vol_greyscale], axis=-1)
+    # Initialize segmentation in RGB
+    shp = seg.shape
+    seg_rgb = np.zeros((shp[0], shp[1], shp[2], 3), dtype=np.uint8)
+    # Set class to appropriate color
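+    # (classes 1 & 2: lungs in blue, class 3: COVID-19 infection in red)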
+    seg_rgb[np.equal(seg, 1)] = [0, 0, 255]
+    seg_rgb[np.equal(seg, 2)] = [0, 0, 255]
+    seg_rgb[np.equal(seg, 3)] = [255, 0, 0]
+    # Get binary array for places where an ROI lives
+    segbin = np.greater(seg, 0)
+    repeated_segbin = np.stack((segbin, segbin, segbin), axis=-1)
+    # Weighted sum where there's a value to overlay
+    alpha = 0.3
+    vol_overlayed = np.where(
+        repeated_segbin,
+        np.round(alpha*seg_rgb+(1-alpha)*vol_rgb).astype(np.uint8),
+        np.round(vol_rgb).astype(np.uint8)
+    )
+    # Return final volume with segmentation overlay
+    return vol_overlayed
+
+#-----------------------------------------------------#
+#                  Score Calculations                 #
+#-----------------------------------------------------#
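+# Dice Similarity Coefficient: DSC = 2*|P ∩ G| / (|P| + |G|)
+# (P = predicted mask of a class, G = its ground truth mask)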
+def calc_DSC(truth, pred, classes):
+    dice_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        gt = np.equal(truth, i)
+        pd = np.equal(pred, i)
+        # Calculate Dice (NumPy integer division by zero does not raise a
+        # ZeroDivisionError, so guard against empty masks explicitly)
+        denominator = pd.sum() + gt.sum()
+        if denominator == 0:
+            dice_scores.append(0.0)
+        else:
+            dice_scores.append(2 * np.logical_and(pd, gt).sum() / denominator)
+    # Return computed Dice Similarity Coefficients
+    return dice_scores
+
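+# Intersection-over-Union (Jaccard index): IoU = |P ∩ G| / |P ∪ G|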
+def calc_IoU(truth, pred, classes):
+    iou_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        gt = np.equal(truth, i)
+        pd = np.equal(pred, i)
+        # Calculate IoU (guard against empty masks, since NumPy division
+        # by zero yields NaN rather than raising ZeroDivisionError)
+        intersection = np.logical_and(pd, gt).sum()
+        union = pd.sum() + gt.sum() - intersection
+        if union == 0:
+            iou_scores.append(0.0)
+        else:
+            iou_scores.append(intersection / union)
+    # Return computed IoU
+    return iou_scores
+
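+# Sensitivity (recall): TP / (TP + FN)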
+def calc_Sensitivity(truth, pred, classes):
+    sens_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        gt = np.equal(truth, i)
+        pd = np.equal(pred, i)
+        # Calculate sensitivity (guard against an empty ground truth mask)
+        if gt.sum() == 0:
+            sens_scores.append(0.0)
+        else:
+            sens_scores.append(np.logical_and(pd, gt).sum() / gt.sum())
+    # Return computed sensitivity scores
+    return sens_scores
+
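+# Specificity: TN / (TN + FP)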
+def calc_Specificity(truth, pred, classes):
+    spec_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        not_gt = np.logical_not(np.equal(truth, i))
+        not_pd = np.logical_not(np.equal(pred, i))
+        # Calculate specificity (guard against an all-foreground ground truth)
+        if not_gt.sum() == 0:
+            spec_scores.append(0.0)
+        else:
+            spec_scores.append(np.logical_and(not_pd, not_gt).sum() / not_gt.sum())
+    # Return computed specificity scores
+    return spec_scores
+
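+# Accuracy: (TP + TN) / (TP + TN + FP + FN)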
+def calc_Accuracy(truth, pred, classes):
+    acc_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        gt = np.equal(truth, i)
+        pd = np.equal(pred, i)
+        not_gt = np.logical_not(gt)
+        not_pd = np.logical_not(pd)
+        # Calculate accuracy (the denominator is the total voxel count,
+        # which is never zero for a non-empty volume)
+        acc = (np.logical_and(pd, gt).sum() +
+               np.logical_and(not_pd, not_gt).sum()) / gt.size
+        acc_scores.append(acc)
+    # Return computed accuracy scores
+    return acc_scores
+
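+# Precision: TP / (TP + FP)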
+def calc_Precision(truth, pred, classes):
+    prec_scores = []
+    # Iterate over each class
+    for i in range(classes):
+        gt = np.equal(truth, i)
+        pd = np.equal(pred, i)
+        # Calculate precision (guard against an empty prediction mask)
+        if pd.sum() == 0:
+            prec_scores.append(0.0)
+        else:
+            prec_scores.append(np.logical_and(pd, gt).sum() / pd.sum())
+    # Return computed precision scores
+    return prec_scores
+
+#-----------------------------------------------------#
+#                      Plotting                       #
+#-----------------------------------------------------#
+def plot_fitting(df_log):
+    # Melt Data Set
+    df_fitting = df_log.melt(id_vars=["epoch"],
+                             value_vars=["loss", "val_loss"],
+                             var_name="Dataset",
+                             value_name="score")
+    # Plot
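+    # (method="gpr" fits a Gaussian process smoother across the epochs of all folds)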
+    fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)"))
+                  + geom_smooth(method="gpr", size=2)
+                  + ggtitle("Fitting Curve during Training")
+                  + xlab("Epoch")
+                  + ylab("Loss Function")
+                  + scale_y_continuous(limits=[0, 3])
+                  + scale_colour_discrete(name="Dataset",
+                                          labels=["Training", "Validation"])
+                  + theme_bw(base_size=28))
+    # # Plot
+    # fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)"))
+    #               + geom_line()
+    #               + ggtitle("Fitting Curve during Training")
+    #               + xlab("Epoch")
+    #               + ylab("Loss Function")
+    #               + theme_bw())
+    fig.save(filename="fitting_curve.png", path=eval_path,
+             width=12, height=10, dpi=300)
+
+#-----------------------------------------------------#
+#                    Run Evaluation                   #
+#-----------------------------------------------------#
+# Initialize Data IO Interface for NIfTI data
+## We are using 4 classes: background, lung_left, lung_right, COVID-19
+interface = NIFTI_interface(channels=1, classes=4)
+
+# Create Data IO object to load and write samples in the file structure
+data_io = Data_IO(interface, input_path="data", output_path=pred_path)
+
+# Access all available samples in our file structure
+sample_list = data_io.get_indiceslist()
+sample_list.sort()
+
+# Initialize dataframe
+cols = ["index", "score", "background", "lung_L", "lung_R", "infection"]
+df = pd.DataFrame(data=[], dtype=np.float64, columns=cols)
+
+# Iterate over each sample
+for index in tqdm(sample_list):
+    # Load a sample including its image, ground truth and prediction
+    sample = data_io.sample_loader(index, load_seg=True, load_pred=True)
+    # Access image, ground truth and prediction data
+    image = sample.img_data
+    truth = sample.seg_data
+    pred = sample.pred_data
+    # Compute diverse Scores
+    dsc = calc_DSC(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "DSC"] + dsc, index=cols),
+                   ignore_index=True)
+    iou = calc_IoU(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "IoU"] + iou, index=cols),
+                   ignore_index=True)
+    sens = calc_Sensitivity(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "Sens"] + sens, index=cols),
+                   ignore_index=True)
+    spec = calc_Specificity(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "Spec"] + spec, index=cols),
+                   ignore_index=True)
+    prec = calc_Precision(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "Prec"] + prec, index=cols),
+                   ignore_index=True)
+    acc = calc_Accuracy(truth, pred, classes=4)
+    df = df.append(pd.Series([index, "Acc"] + acc, index=cols),
+                   ignore_index=True)
+    # Compute Visualization
+    # visualize_evaluation(index, image, truth, pred, os.path.join(eval_path, "visualization"))
+
+# Output complete dataframe
+print(df)
+# Store complete dataframe to disk
+path_res_detailed = os.path.join(eval_path, "cv_results.detailed.csv")
+df.to_csv(path_res_detailed, index=False)
+
+# Initialize fitting logging dataframe
+cols = ["epoch", "dice_crossentropy", "dice_soft", "loss", "lr", "tversky_loss",
+        "val_dice_crossentropy", "val_dice_soft", "val_loss","val_tversky_loss",
+        "fold"]
+df_log = pd.DataFrame(data=[], dtype=np.float64, columns=cols)
+cols_val = ["score", "background", "infection", "lungs", "fold"]
+df_global = pd.DataFrame(data=[], dtype=np.float64, columns=cols_val)
+
+# Compute per-fold scores
+for fold in os.listdir(eval_path):
+    # Skip all entries in the evaluation directory which are not cross-validation fold directories
+    if not fold.startswith("fold_"): continue
+    # Identify validation samples of this fold
+    _, val_list = load_disk2fold(os.path.join(eval_path, fold, "sample_list.json"))
+    # Obtain metrics for validation list
+    df_val = df.loc[df["index"].isin(val_list)]
+    # Print out average and std evaluation metrics for the current fold
+    df_avg = df_val.groupby(by="score").mean()
+    df_std = df_val.groupby(by="score").std()
+    print(fold)
+    print(df_avg)
+    print(df_std)
+    # Compute median metrics for the validation list
+    df_val = df_val.groupby(by="score").median()
+    # Combine lung left and lung right class by mean
+    df_val["lungs"] = df_val[["lung_L", "lung_R"]].mean(axis=1)
+    df_val.drop(["lung_L", "lung_R"], axis=1, inplace=True)
+    # Add df_val df to df_global
+    df_val["fold"] = fold
+    df_val = df_val.reset_index()
+    df_global = df_global.append(df_val, ignore_index=True)
+    # Load logging data for fitting plot
+    path_log = os.path.join(eval_path, fold, "history.tsv")
+    df_logfold = pd.read_csv(path_log, header=0, sep="\t")
+    df_logfold["fold"] = fold
+    # Add logging data to global fitting log dataframe
+    df_log = df_log.append(df_logfold, ignore_index=True)
+
+# Run plotting of fitting process
+plot_fitting(df_log)
+
+# Output cross-validation results
+print(df_global)
+# Save cross-validation results to disk
+path_res_global = os.path.join(eval_path, "cv_results.final.csv")
+df_global.to_csv(path_res_global, index=False)