--- a +++ b/scripts/cv_analysis/evaluate.py @@ -0,0 +1,341 @@ +#==============================================================================# +# Author: Dominik Müller # +# Copyright: 2020 IT-Infrastructure for Translational Medical Research, # +# University of Augsburg # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <http://www.gnu.org/licenses/>. # +#==============================================================================# +#-----------------------------------------------------# +# Library imports # +#-----------------------------------------------------# +import matplotlib.pyplot as plt +import matplotlib.animation as animation +import numpy as np +import pandas as pd +import os +from tqdm import tqdm +from miscnn.data_loading.interfaces import NIFTI_interface +from miscnn import Data_IO +from miscnn.evaluation.cross_validation import load_disk2fold +from plotnine import * +import argparse + +#-----------------------------------------------------# +# Argparser # +#-----------------------------------------------------# +parser = argparse.ArgumentParser(description="Automated COVID-19 Segmentation") +parser.add_argument("-e", "--evaluation", help="Path to evaluation directory", + required=True, type=str, dest="eval") +parser.add_argument("-p", "--predictions", help="Path to predictions directory", + required=True, type=str, dest="pred") +args = parser.parse_args() +eval_path = args.eval +pred_path = args.pred + +#-----------------------------------------------------# +# Visualization # +#-----------------------------------------------------# +def visualize_evaluation(case_id, vol, truth, pred, eva_path): + # Squeeze image files to remove channel axis + vol = np.squeeze(vol, axis=-1) + truth = np.squeeze(truth, axis=-1) + pred = np.squeeze(pred, axis=-1) + # Color volumes according to truth and pred segmentation + vol_raw = overlay_segmentation(vol, np.zeros(vol.shape)) + vol_truth = overlay_segmentation(vol, truth) + vol_pred = overlay_segmentation(vol, pred) + # Create a figure and two axes objects from matplot + fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + # Initialize the two subplots (axes) with an empty image + data = np.zeros(vol.shape[0:2]) + ax1.set_title("CT Scan") + ax2.set_title("Ground Truth") + ax3.set_title("Prediction") + img1 = ax1.imshow(data) + img2 = ax2.imshow(data) + img3 = ax3.imshow(data) + # Update function for both images to show the slice for the current frame + def update(i): + plt.suptitle("Case ID: " + str(case_id) + " - " + "Slice: " + str(i)) + img1.set_data(vol_raw[:,:,i]) + img2.set_data(vol_truth[:,:,i]) + img3.set_data(vol_pred[:,:,i]) + return [img1, img2, img3] + # Compute the animation (gif) + ani = animation.FuncAnimation(fig, update, frames=truth.shape[2], + interval=10, repeat_delay=0, blit=False) + # Set up the output path for the gif + if not os.path.exists(eva_path): + os.mkdir(eva_path) + file_name = "visualization." + str(case_id).zfill(5) + ".gif" + out_path = os.path.join(eva_path, file_name) + # Save the animation (gif) + ani.save(out_path, writer='imagemagick', fps=20, dpi=150) + # Close the matplot + plt.close() + +# Based on: https://github.com/neheller/kits19/blob/master/starter_code/visualize.py +def overlay_segmentation(vol, seg): + # Clip intensities to -1250 and +250 + vol = np.clip(vol, -1250, 250) + # Scale volume to greyscale range + vol_scaled = (vol - np.min(vol)) / (np.max(vol) - np.min(vol)) + vol_greyscale = np.around(vol_scaled * 255, decimals=0).astype(np.uint8) + # Convert volume to RGB + vol_rgb = np.stack([vol_greyscale, vol_greyscale, vol_greyscale], axis=-1) + # Initialize segmentation in RGB + shp = seg.shape + seg_rgb = np.zeros((shp[0], shp[1], shp[2], 3), dtype=np.int) + # Set class to appropriate color + seg_rgb[np.equal(seg, 1)] = [0, 0, 255] + seg_rgb[np.equal(seg, 2)] = [0, 0, 255] + seg_rgb[np.equal(seg, 3)] = [255, 0, 0] + # Get binary array for places where an ROI lives + segbin = np.greater(seg, 0) + repeated_segbin = np.stack((segbin, segbin, segbin), axis=-1) + # Weighted sum where there's a value to overlay + alpha = 0.3 + vol_overlayed = np.where( + repeated_segbin, + np.round(alpha*seg_rgb+(1-alpha)*vol_rgb).astype(np.uint8), + np.round(vol_rgb).astype(np.uint8) + ) + # Return final volume with segmentation overlay + return vol_overlayed + +#-----------------------------------------------------# +# Score Calculations # +#-----------------------------------------------------# +def calc_DSC(truth, pred, classes): + dice_scores = [] + # Iterate over each class + for i in range(classes): + try: + gt = np.equal(truth, i) + pd = np.equal(pred, i) + # Calculate Dice + dice = 2*np.logical_and(pd, gt).sum() / (pd.sum() + gt.sum()) + dice_scores.append(dice) + except ZeroDivisionError: + dice_scores.append(0.0) + # Return computed Dice Similarity Coefficients + return dice_scores + +def calc_IoU(truth, pred, classes): + iou_scores = [] + # Iterate over each class + for i in range(classes): + try: + gt = np.equal(truth, i) + pd = np.equal(pred, i) + # Calculate iou + iou = np.logical_and(pd, gt).sum() / (pd.sum() + gt.sum() - np.logical_and(pd, gt).sum()) + iou_scores.append(iou) + except ZeroDivisionError: + iou_scores.append(0.0) + # Return computed IoU + return iou_scores + +def calc_Sensitivity(truth, pred, classes): + sens_scores = [] + # Iterate over each class + for i in range(classes): + try: + gt = np.equal(truth, i) + pd = np.equal(pred, i) + # Calculate sensitivity + sens = np.logical_and(pd, gt).sum() / gt.sum() + sens_scores.append(sens) + except ZeroDivisionError: + sens_scores.append(0.0) + # Return computed sensitivity scores + return sens_scores + +def calc_Specificity(truth, pred, classes): + spec_scores = [] + # Iterate over each class + for i in range(classes): + try: + not_gt = np.logical_not(np.equal(truth, i)) + not_pd = np.logical_not(np.equal(pred, i)) + # Calculate specificity + spec = np.logical_and(not_pd, not_gt).sum() / (not_gt).sum() + spec_scores.append(spec) + except ZeroDivisionError: + spec_scores.append(0.0) + # Return computed specificity scores + return spec_scores + +def calc_Accuracy(truth, pred, classes): + acc_scores = [] + # Iterate over each class + for i in range(classes): + try: + gt = np.equal(truth, i) + pd = np.equal(pred, i) + not_gt = np.logical_not(np.equal(truth, i)) + not_pd = np.logical_not(np.equal(pred, i)) + # Calculate accuracy + acc = (np.logical_and(pd, gt).sum() + \ + np.logical_and(not_pd, not_gt).sum()) / gt.size + acc_scores.append(acc) + except ZeroDivisionError: + acc_scores.append(0.0) + # Return computed accuracy scores + return acc_scores + +def calc_Precision(truth, pred, classes): + prec_scores = [] + # Iterate over each class + for i in range(classes): + try: + gt = np.equal(truth, i) + pd = np.equal(pred, i) + # Calculate precision + prec = np.logical_and(pd, gt).sum() / pd.sum() + prec_scores.append(prec) + except ZeroDivisionError: + prec_scores.append(0.0) + # Return computed precision scores + return prec_scores + +#-----------------------------------------------------# +# Plotting # +#-----------------------------------------------------# +def plot_fitting(df_log): + # Melt Data Set + df_fitting = df_log.melt(id_vars=["epoch"], + value_vars=["loss", "val_loss"], + var_name="Dataset", + value_name="score") + # Plot + fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)")) + + geom_smooth(method="gpr", size=2) + + ggtitle("Fitting Curve during Training") + + xlab("Epoch") + + ylab("Loss Function") + + scale_y_continuous(limits=[0, 3]) + + scale_colour_discrete(name="Dataset", + labels=["Training", "Validation"]) + + theme_bw(base_size=28)) + # # Plot + # fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)")) + # + geom_line() + # + ggtitle("Fitting Curve during Training") + # + xlab("Epoch") + # + ylab("Loss Function") + # + theme_bw()) + fig.save(filename="fitting_curve.png", path=eval_path, + width=12, height=10, dpi=300) + +#-----------------------------------------------------# +# Run Evaluation # +#-----------------------------------------------------# +# Initialize Data IO Interface for NIfTI data +## We are using 4 classes due to [background, lung_left, lung_right, covid-19] +interface = NIFTI_interface(channels=1, classes=4) + +# Create Data IO object to load and write samples in the file structure +data_io = Data_IO(interface, input_path="data", output_path=pred_path) + +# Access all available samples in our file structure +sample_list = data_io.get_indiceslist() +sample_list.sort() + +# Initialize dataframe +cols = ["index", "score", "background", "lung_L", "lung_R", "infection"] +df = pd.DataFrame(data=[], dtype=np.float64, columns=cols) + +# Iterate over each sample +for index in tqdm(sample_list): + # Load a sample including its image, ground truth and prediction + sample = data_io.sample_loader(index, load_seg=True, load_pred=True) + # Access image, ground truth and prediction data + image = sample.img_data + truth = sample.seg_data + pred = sample.pred_data + # Compute diverse Scores + dsc = calc_DSC(truth, pred, classes=4) + df = df.append(pd.Series([index, "DSC"] + dsc, index=cols), + ignore_index=True) + iou = calc_IoU(truth, pred, classes=4) + df = df.append(pd.Series([index, "IoU"] + iou, index=cols), + ignore_index=True) + sens = calc_Sensitivity(truth, pred, classes=4) + df = df.append(pd.Series([index, "Sens"] + sens, index=cols), + ignore_index=True) + spec = calc_Specificity(truth, pred, classes=4) + df = df.append(pd.Series([index, "Spec"] + spec, index=cols), + ignore_index=True) + prec = calc_Precision(truth, pred, classes=4) + df = df.append(pd.Series([index, "Prec"] + prec, index=cols), + ignore_index=True) + acc = calc_Accuracy(truth, pred, classes=4) + df = df.append(pd.Series([index, "Acc"] + acc, index=cols), + ignore_index=True) + # Compute Visualization + # visualize_evaluation(index, image, truth, pred, os.path.join(eval_path, "visualization")) + +# Output complete dataframe +print(df) +# Store complete dataframe to disk +path_res_detailed = os.path.join(eval_path, "cv_results.detailed.csv") +df.to_csv(path_res_detailed, index=False) + +# Initialize fitting logging dataframe +cols = ["epoch", "dice_crossentropy", "dice_soft", "loss", "lr", "tversky_loss", + "val_dice_crossentropy", "val_dice_soft", "val_loss","val_tversky_loss", + "fold"] +df_log = pd.DataFrame(data=[], dtype=np.float64, columns=cols) +cols_val = ["score", "background", "infection", "lungs", "fold"] +df_global = pd.DataFrame(data=[], dtype=np.float64, columns=cols_val) + +# Compute per-fold scores +for fold in os.listdir(eval_path): + # Skip all files in evaluation which are not cross-validation dirs + if not fold.startswith("fold_") : continue + # Identify validation samples of this fold + _, val_list = load_disk2fold(os.path.join(eval_path, fold, "sample_list.json")) + # Obtain metrics for validation list + df_val = df.loc[df["index"].isin(val_list)] + # Print out average and std evaluation metrics for the current fold + df_avg = df_val.groupby(by="score").mean() + df_std = df_val.groupby(by="score").std() + print(fold) + print(df_avg) + print(df_std) + # Compute average metrics for validation list + df_val = df_val.groupby(by="score").median() + # Combine lung left and lung right class by mean + df_val["lungs"] = df_val[["lung_L", "lung_R"]].mean(axis=1) + df_val.drop(["lung_L", "lung_R"], axis=1, inplace=True) + # Add df_val df to df_global + df_val["fold"] = fold + df_val = df_val.reset_index() + df_global = df_global.append(df_val, ignore_index=True) + # Load logging data for fitting plot + path_log = os.path.join(eval_path, fold, "history.tsv") + df_logfold = pd.read_csv(path_log, header=0, sep="\t") + df_logfold["fold"] = fold + # Add logging data to global fitting log dataframe + df_log = df_log.append(df_logfold, ignore_index=True) + +# Run plotting of fitting process +plot_fitting(df_log) + +# Output cross-validation results +print(df_global) +# Save cross-validation results to disk +path_res_global = os.path.join(eval_path, "cv_results.final.csv") +df_global.to_csv(path_res_global, index=False)