Diff of /scripts/run_evaluation.py [000000] .. [2afb35]

Switch to unified view

a b/scripts/run_evaluation.py
1
    #==============================================================================#
2
#  Author:       Dominik Müller                                                #
3
#  Copyright:    2020 IT-Infrastructure for Translational Medical Research,    #
4
#                University of Augsburg                                        #
5
#                                                                              #
6
#  This program is free software: you can redistribute it and/or modify        #
7
#  it under the terms of the GNU General Public License as published by        #
8
#  the Free Software Foundation, either version 3 of the License, or           #
9
#  (at your option) any later version.                                         #
10
#                                                                              #
11
#  This program is distributed in the hope that it will be useful,             #
12
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
13
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
14
#  GNU General Public License for more details.                                #
15
#                                                                              #
16
#  You should have received a copy of the GNU General Public License           #
17
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
18
#==============================================================================#
19
#-----------------------------------------------------#
20
#                   Library imports                   #
21
#-----------------------------------------------------#
22
import matplotlib.pyplot as plt
23
import matplotlib.animation as animation
24
import numpy as np
25
import pandas as pd
26
import os
27
from tqdm import tqdm
28
from miscnn.data_loading.interfaces import NIFTI_interface
29
from miscnn import Data_IO
30
from plotnine import *
31
32
#-----------------------------------------------------#
33
#                    Visualization                    #
34
#-----------------------------------------------------#
35
def visualize_evaluation(case_id, vol, truth, pred, eva_path):
    """Save an animated GIF comparing scan, ground truth and prediction.

    The three volumes are shown side by side, one animation frame per index
    along axis 2. The file is written to
    ``<eva_path>/visualization.<case_id zero-padded to 5>.gif``.

    Args:
        case_id: Identifier of the sample; used in the title and file name.
        vol: Image volume with a trailing channel axis of size 1.
        truth: Ground truth segmentation with a trailing axis of size 1.
        pred: Predicted segmentation with a trailing axis of size 1.
        eva_path: Output directory; created (including parents) if missing.
    """
    # Squeeze image files to remove channel axis
    vol = np.squeeze(vol, axis=-1)
    truth = np.squeeze(truth, axis=-1)
    pred = np.squeeze(pred, axis=-1)
    # Color volumes according to truth and pred segmentation
    # (an all-zero segmentation yields the plain greyscale volume)
    vol_raw = overlay_segmentation(vol, np.zeros(vol.shape))
    vol_truth = overlay_segmentation(vol, truth)
    vol_pred = overlay_segmentation(vol, pred)
    # Create a figure with three axes objects from matplot
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    try:
        # Initialize the three subplots (axes) with an empty image
        data = np.zeros(vol.shape[0:2])
        ax1.set_title("CT Scan")
        ax2.set_title("Ground Truth")
        ax3.set_title("Prediction")
        img1 = ax1.imshow(data)
        img2 = ax2.imshow(data)
        img3 = ax3.imshow(data)

        # Update function for all images to show the slice of the current frame
        def update(i):
            # Address the figure explicitly instead of the implicit
            # "current figure" so other open figures cannot interfere
            fig.suptitle("Case ID: " + str(case_id) + " - " + "Slice: " + str(i))
            img1.set_data(vol_raw[:, :, i])
            img2.set_data(vol_truth[:, :, i])
            img3.set_data(vol_pred[:, :, i])
            return [img1, img2, img3]

        # Compute the animation (gif)
        ani = animation.FuncAnimation(fig, update, frames=truth.shape[2],
                                      interval=10, repeat_delay=0, blit=False)
        # Set up the output path for the gif; makedirs also creates missing
        # parent directories and tolerates an already existing directory
        os.makedirs(eva_path, exist_ok=True)
        file_name = "visualization." + str(case_id).zfill(5) + ".gif"
        out_path = os.path.join(eva_path, file_name)
        # Save the animation (gif)
        ani.save(out_path, writer='imagemagick', fps=20, dpi=150)
    finally:
        # Close this figure even if rendering or saving failed
        plt.close(fig)
73
74
# Based on: https://github.com/neheller/kits19/blob/master/starter_code/visualize.py
75
def overlay_segmentation(vol, seg):
    """Blend a color-coded segmentation onto a greyscale version of a volume.

    Args:
        vol: Array of image intensities.
        seg: Array of class labels with the same shape as ``vol``.
            Labels 1 and 2 are colored blue, label 3 red; any other label
            greater than 0 is blended with black.

    Returns:
        ``uint8`` array of shape ``vol.shape + (3,)`` holding RGB values.
    """
    # Clip intensities to -1250 and +250
    vol = np.clip(vol, -1250, 250)
    # Scale volume to greyscale range; a constant volume has no intensity
    # range, so map it to 0 instead of dividing by zero (NaN)
    vol_min = np.min(vol)
    vol_range = np.max(vol) - vol_min
    if vol_range > 0:
        vol_scaled = (vol - vol_min) / vol_range
    else:
        vol_scaled = np.zeros(np.shape(vol), dtype=np.float64)
    vol_greyscale = np.around(vol_scaled * 255, decimals=0).astype(np.uint8)
    # Convert volume to RGB
    vol_rgb = np.stack([vol_greyscale, vol_greyscale, vol_greyscale], axis=-1)
    # Initialize segmentation in RGB
    # (builtin int replaces the np.int alias removed from NumPy)
    seg_rgb = np.zeros(tuple(np.shape(seg)) + (3,), dtype=int)
    # Set class to appropriate color
    seg_rgb[np.equal(seg, 1)] = [0, 0, 255]
    seg_rgb[np.equal(seg, 2)] = [0, 0, 255]
    seg_rgb[np.equal(seg, 3)] = [255, 0, 0]
    # Get binary array for places where an ROI lives
    segbin = np.greater(seg, 0)
    repeated_segbin = np.stack((segbin, segbin, segbin), axis=-1)
    # Weighted sum where there's a value to overlay
    alpha = 0.3
    vol_overlayed = np.where(
        repeated_segbin,
        np.round(alpha * seg_rgb + (1 - alpha) * vol_rgb).astype(np.uint8),
        np.round(vol_rgb).astype(np.uint8)
    )
    # Return final volume with segmentation overlay
    return vol_overlayed
102
103
#-----------------------------------------------------#
104
#                  Score Calculations                 #
105
#-----------------------------------------------------#
106
def calc_DSC(truth, pred, classes):
    """Compute the Dice Similarity Coefficient for each class.

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. A class that occurs in neither
        ``truth`` nor ``pred`` is scored 0.0.
    """
    dice_scores = []
    # Iterate over each class
    for i in range(classes):
        gt = np.equal(truth, i)
        pred_mask = np.equal(pred, i)
        denominator = pred_mask.sum() + gt.sum()
        # NumPy division by zero yields nan (with a warning) rather than
        # raising ZeroDivisionError, so the empty case is checked explicitly
        if denominator == 0:
            dice_scores.append(0.0)
            continue
        # Calculate Dice
        intersection = np.logical_and(pred_mask, gt).sum()
        dice_scores.append(float(2 * intersection / denominator))
    # Return computed Dice Similarity Coefficients
    return dice_scores
120
121
def calc_IoU(truth, pred, classes):
    """Compute the Intersection over Union for each class.

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. A class that occurs in neither
        ``truth`` nor ``pred`` is scored 0.0.
    """
    iou_scores = []
    # Iterate over each class
    for i in range(classes):
        gt = np.equal(truth, i)
        pred_mask = np.equal(pred, i)
        intersection = np.logical_and(pred_mask, gt).sum()
        union = pred_mask.sum() + gt.sum() - intersection
        # NumPy division by zero yields nan (with a warning) rather than
        # raising ZeroDivisionError, so the empty case is checked explicitly
        if union == 0:
            iou_scores.append(0.0)
            continue
        # Calculate iou
        iou_scores.append(float(intersection / union))
    # Return computed IoU
    return iou_scores
135
136
def calc_Sensitivity(truth, pred, classes):
    """Compute the sensitivity (true positive rate) for each class.

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. A class without any ground truth
        element is scored 0.0.
    """
    sens_scores = []
    # Iterate over each class
    for i in range(classes):
        gt = np.equal(truth, i)
        pred_mask = np.equal(pred, i)
        positives = gt.sum()
        # NumPy division by zero yields nan/inf (with a warning) rather than
        # raising ZeroDivisionError, so the empty case is checked explicitly
        if positives == 0:
            sens_scores.append(0.0)
            continue
        # Calculate sensitivity
        true_positives = np.logical_and(pred_mask, gt).sum()
        sens_scores.append(float(true_positives / positives))
    # Return computed sensitivity scores
    return sens_scores
150
151
def calc_Specificity(truth, pred, classes):
    """Compute the specificity (true negative rate) for each class.

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. A class that covers every ground
        truth element (no negatives) is scored 0.0.
    """
    spec_scores = []
    # Iterate over each class
    for i in range(classes):
        not_gt = np.logical_not(np.equal(truth, i))
        not_pred = np.logical_not(np.equal(pred, i))
        negatives = not_gt.sum()
        # NumPy division by zero yields nan (with a warning) rather than
        # raising ZeroDivisionError, so the empty case is checked explicitly
        if negatives == 0:
            spec_scores.append(0.0)
            continue
        # Calculate specificity
        true_negatives = np.logical_and(not_pred, not_gt).sum()
        spec_scores.append(float(true_negatives / negatives))
    # Return computed specificity scores
    return spec_scores
165
166
def calc_Accuracy(truth, pred, classes):
    """Compute the per-class accuracy (TP + TN over all elements).

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. Empty input is scored 0.0.
    """
    acc_scores = []
    # Iterate over each class
    for i in range(classes):
        gt = np.equal(truth, i)
        pred_mask = np.equal(pred, i)
        # NumPy division by zero yields nan (with a warning) rather than
        # raising ZeroDivisionError, so empty input is checked explicitly
        if gt.size == 0:
            acc_scores.append(0.0)
            continue
        # Calculate accuracy
        true_positives = np.logical_and(pred_mask, gt).sum()
        true_negatives = np.logical_and(np.logical_not(pred_mask),
                                        np.logical_not(gt)).sum()
        acc_scores.append(float((true_positives + true_negatives) / gt.size))
    # Return computed accuracy scores
    return acc_scores
183
184
def calc_Precision(truth, pred, classes):
    """Compute the precision (positive predictive value) for each class.

    Args:
        truth: Array of ground truth class labels.
        pred: Array of predicted class labels, same shape as ``truth``.
        classes: Number of classes; labels ``0 .. classes-1`` are scored.

    Returns:
        List with one float per class. A class that was never predicted
        is scored 0.0.
    """
    prec_scores = []
    # Iterate over each class
    for i in range(classes):
        gt = np.equal(truth, i)
        pred_mask = np.equal(pred, i)
        predicted = pred_mask.sum()
        # NumPy division by zero yields nan (with a warning) rather than
        # raising ZeroDivisionError, so the empty case is checked explicitly
        if predicted == 0:
            prec_scores.append(0.0)
            continue
        # Calculate precision
        true_positives = np.logical_and(pred_mask, gt).sum()
        prec_scores.append(float(true_positives / predicted))
    # Return computed precision scores
    return prec_scores
198
199
#-----------------------------------------------------#
200
#                      Plotting                       #
201
#-----------------------------------------------------#
202
def plot_fitting(df_log, out_dir="evaluation", filename="fitting_curve.png"):
    """Plot smoothed training and validation loss per epoch and save it.

    Args:
        df_log: DataFrame with the columns "epoch", "loss" and "val_loss".
        out_dir: Directory the image is written to.
        filename: File name of the image inside ``out_dir``.
    """
    # Melt Data Set: one row per (epoch, Dataset) pair with its loss as score
    df_fitting = df_log.melt(id_vars=["epoch"],
                             value_vars=["loss", "val_loss"],
                             var_name="Dataset",
                             value_name="score")
    # Plot smoothed curves; the y axis limits are fixed to [0, 3]
    fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)"))
           + geom_smooth(method="gpr", size=2)
           + ggtitle("Fitting Curve during Training")
           + xlab("Epoch")
           + ylab("Loss Function")
           + scale_y_continuous(limits=[0, 3])
           + scale_colour_discrete(name="Dataset",
                                   labels=["Training", "Validation"])
           + theme_bw(base_size=28))
    # Store the figure to disk
    fig.save(filename=filename, path=out_dir,
             width=12, height=10, dpi=300)
227
228
#-----------------------------------------------------#
229
#                    Run Evaluation                   #
230
#-----------------------------------------------------#
231
# Initialize Data IO Interface for NIfTI data
## We are using 4 classes due to [background, lung_left, lung_right, covid-19]
n_classes = 4
interface = NIFTI_interface(channels=1, classes=n_classes)

# Create Data IO object to load and write samples in the file structure
data_io = Data_IO(interface, input_path="data", output_path="predictions")

# Access all available samples in our file structure
sample_list = data_io.get_indiceslist()
sample_list.sort()

# Column layout of the detailed result table: one row per (sample, metric)
cols = ["index", "score", "background", "lung_L", "lung_R", "infection"]
# Metric label -> scoring function, in the order the rows are written
metrics = [("DSC", calc_DSC),
           ("IoU", calc_IoU),
           ("Sens", calc_Sensitivity),
           ("Spec", calc_Specificity),
           ("Prec", calc_Precision),
           ("Acc", calc_Accuracy)]

# Results and visualizations are written below this directory
os.makedirs("evaluation", exist_ok=True)

# Rows are collected in a list and converted once at the end, because
# DataFrame.append was removed in pandas 2.0 and copies the frame per call
rows = []
# Iterate over each sample
for index in tqdm(sample_list):
    # Load a sample including its image, ground truth and prediction
    sample = data_io.sample_loader(index, load_seg=True, load_pred=True)
    # Access image, ground truth and prediction data
    image = sample.img_data
    truth = sample.seg_data
    pred = sample.pred_data
    # Compute diverse Scores
    for label, metric in metrics:
        scores = metric(truth, pred, classes=n_classes)
        rows.append([index, label] + list(scores))
    # Compute Visualization
    visualize_evaluation(index, image, truth, pred, "evaluation/visualization")

# Build the complete dataframe
df = pd.DataFrame(rows, columns=cols)
# Output complete dataframe
print(df)
# Store complete dataframe to disk
path_res_detailed = os.path.join("evaluation", "cv_results.detailed.csv")
df.to_csv(path_res_detailed, index=False)
281
282
# Column layout of the fitting log and of the per-fold summary
cols = ["epoch", "dice_crossentropy", "dice_soft", "loss", "lr", "tversky_loss",
        "val_dice_crossentropy", "val_dice_soft", "val_loss", "val_tversky_loss",
        "fold"]
cols_val = ["score", "background", "infection", "lungs", "fold"]
# Metric columns of the detailed result table
class_cols = ["background", "lung_L", "lung_R", "infection"]


def _combine_frames(frames, columns):
    """Concatenate frames with `columns` first; empty float frame if none."""
    if not frames:
        return pd.DataFrame(data=[], dtype=np.float64, columns=columns)
    combined = pd.concat(frames, ignore_index=True)
    # Keep the expected columns in front, followed by any additional ones
    extra = [c for c in combined.columns if c not in columns]
    return combined.reindex(columns=columns + extra)


# Per-fold frames are collected and concatenated once after the loop
# (DataFrame.append was removed in pandas 2.0)
log_frames = []
global_frames = []

# Compute per-fold scores (sorted for a reproducible fold order)
for fold in sorted(os.listdir("evaluation")):
    # Skip everything in evaluation which is not a cross-validation dir
    if not fold.startswith("fold_"):
        continue
    if not os.path.isdir(os.path.join("evaluation", fold)):
        continue
    # Identify validation samples of this fold: second row of the file,
    # without its first field
    # NOTE(review): presumably the first field is a row label - confirm
    # against the writer of sample_list.csv
    path_detval = os.path.join("evaluation", fold, "sample_list.csv")
    detval = pd.read_csv(path_detval, sep=" ", header=None)
    detval = detval.iloc[1].dropna()
    val_list = detval.values.tolist()[1:]
    # Obtain metrics for validation list; only the metric columns are kept
    # so the sample identifier is never aggregated
    df_val = df.loc[df["index"].isin(val_list), ["score"] + class_cols].copy()
    df_val[class_cols] = df_val[class_cols].astype(np.float64)
    grouped = df_val.groupby(by="score")
    # Print out average and std evaluation metrics for the current fold
    df_avg = grouped.mean()
    df_std = grouped.std()
    print(fold)
    print(df_avg)
    print(df_std)
    # Compute median metrics for validation list
    df_val = grouped.median()
    # Combine lung left and lung right class by mean
    df_val["lungs"] = df_val[["lung_L", "lung_R"]].mean(axis=1)
    df_val = df_val.drop(["lung_L", "lung_R"], axis=1)
    # Add df_val df to the global results
    df_val["fold"] = fold
    df_val = df_val.reset_index()
    global_frames.append(df_val)
    # Load logging data for fitting plot
    path_log = os.path.join("evaluation", fold, "history.tsv")
    df_logfold = pd.read_csv(path_log, header=0, sep="\t")
    df_logfold["fold"] = fold
    # Add logging data to global fitting log
    log_frames.append(df_logfold)

# Build the global dataframes
df_log = _combine_frames(log_frames, cols)
df_global = _combine_frames(global_frames, cols_val)

# Run plotting of fitting process
plot_fitting(df_log)

# Output cross-validation results
print(df_global)
# Save cross-validation results to disk
path_res_global = os.path.join("evaluation", "cv_results.final.csv")
df_global.to_csv(path_res_global, index=False)