scripts/run_evaluation.py

#==============================================================================#
# Author:    Dominik Müller
# Copyright: 2020 IT-Infrastructure for Translational Medical Research,
#            University of Augsburg
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#==============================================================================#
#-----------------------------------------------------#
#                   Library imports                    #
#-----------------------------------------------------#
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from miscnn.data_loading.interfaces import NIFTI_interface
from miscnn import Data_IO
from plotnine import *

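# Note: this script is meant to be run after a completed MIScnn pipeline run.
# It reads predicted segmentations from "predictions/" and per-fold training
# artifacts from "evaluation/fold_*" (see the loops further below).
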
#-----------------------------------------------------#
#                    Visualization                     #
#-----------------------------------------------------#
def visualize_evaluation(case_id, vol, truth, pred, eva_path):
    # Squeeze image files to remove channel axis
    vol = np.squeeze(vol, axis=-1)
    truth = np.squeeze(truth, axis=-1)
    pred = np.squeeze(pred, axis=-1)
    # Color volumes according to truth and pred segmentation
    vol_raw = overlay_segmentation(vol, np.zeros(vol.shape))
    vol_truth = overlay_segmentation(vol, truth)
    vol_pred = overlay_segmentation(vol, pred)
    # Create a figure and three axes objects via matplotlib
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    # Initialize the three subplots (axes) with an empty image
    data = np.zeros(vol.shape[0:2])
    ax1.set_title("CT Scan")
    ax2.set_title("Ground Truth")
    ax3.set_title("Prediction")
    img1 = ax1.imshow(data)
    img2 = ax2.imshow(data)
    img3 = ax3.imshow(data)
    # Update function for all three images to show the slice of the current frame
    def update(i):
        plt.suptitle("Case ID: " + str(case_id) + " - " + "Slice: " + str(i))
        img1.set_data(vol_raw[:,:,i])
        img2.set_data(vol_truth[:,:,i])
        img3.set_data(vol_pred[:,:,i])
        return [img1, img2, img3]
    # Compute the animation (gif)
    ani = animation.FuncAnimation(fig, update, frames=truth.shape[2],
                                  interval=10, repeat_delay=0, blit=False)
    # Set up the output path for the gif
    if not os.path.exists(eva_path):
        os.makedirs(eva_path)
    file_name = "visualization." + str(case_id).zfill(5) + ".gif"
    out_path = os.path.join(eva_path, file_name)
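    # Saving with writer='imagemagick' requires an ImageMagick installation on
    # the system; matplotlib's built-in 'pillow' writer would be a possible
    # alternative if ImageMagick is not available.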
    # Save the animation (gif)
    ani.save(out_path, writer='imagemagick', fps=20, dpi=150)
    # Close the matplotlib figure
    plt.close()

# Based on: https://github.com/neheller/kits19/blob/master/starter_code/visualize.py
def overlay_segmentation(vol, seg):
    # Clip intensities to -1250 and +250
    vol = np.clip(vol, -1250, 250)
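    # (this clipping range roughly corresponds to a CT lung window in
    # Hounsfield units, keeping lung tissue visible in the greyscale image)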
    # Scale volume to greyscale range
    vol_scaled = (vol - np.min(vol)) / (np.max(vol) - np.min(vol))
    vol_greyscale = np.around(vol_scaled * 255, decimals=0).astype(np.uint8)
    # Convert volume to RGB
    vol_rgb = np.stack([vol_greyscale, vol_greyscale, vol_greyscale], axis=-1)
    # Initialize segmentation in RGB
    shp = seg.shape
    seg_rgb = np.zeros((shp[0], shp[1], shp[2], 3), dtype=np.uint8)
    # Set classes to appropriate colors: lungs (1 & 2) in blue, infection (3) in red
    seg_rgb[np.equal(seg, 1)] = [0, 0, 255]
    seg_rgb[np.equal(seg, 2)] = [0, 0, 255]
    seg_rgb[np.equal(seg, 3)] = [255, 0, 0]
    # Get binary array for places where an ROI lives
    segbin = np.greater(seg, 0)
    repeated_segbin = np.stack((segbin, segbin, segbin), axis=-1)
    # Weighted sum where there's a value to overlay
    alpha = 0.3
    vol_overlayed = np.where(
        repeated_segbin,
        np.round(alpha*seg_rgb + (1-alpha)*vol_rgb).astype(np.uint8),
        np.round(vol_rgb).astype(np.uint8)
    )
    # Return final volume with segmentation overlay
    return vol_overlayed

#-----------------------------------------------------#
#                  Score Calculations                  #
#-----------------------------------------------------#
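# All metrics below are computed per class on the binary masks of the ground
# truth (gt) and the prediction (pd). In TP/FP/TN/FN terms they correspond to:
#   DSC         = 2*TP / (2*TP + FP + FN)
#   IoU         = TP / (TP + FP + FN)
#   Sensitivity = TP / (TP + FN)
#   Specificity = TN / (TN + FP)
#   Accuracy    = (TP + TN) / (TP + TN + FP + FN)
#   Precision   = TP / (TP + FP)
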
def calc_DSC(truth, pred, classes):
    dice_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            gt = np.equal(truth, i)
            pd = np.equal(pred, i)
            # Calculate Dice (cast sums to int so that an absent class raises
            # a ZeroDivisionError instead of yielding a NumPy nan)
            dice = 2 * int(np.logical_and(pd, gt).sum()) / \
                   (int(pd.sum()) + int(gt.sum()))
            dice_scores.append(dice)
        except ZeroDivisionError:
            dice_scores.append(0.0)
    # Return computed Dice Similarity Coefficients
    return dice_scores
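
# Toy illustration (not part of the pipeline): for
#   truth = np.array([0, 1, 1, 2]) and pred = np.array([0, 1, 2, 2])
# calc_DSC(truth, pred, classes=3) returns [1.0, 0.666..., 0.666...].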

def calc_IoU(truth, pred, classes):
    iou_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            gt = np.equal(truth, i)
            pd = np.equal(pred, i)
            # Calculate IoU (int casts: see calc_DSC)
            overlap = int(np.logical_and(pd, gt).sum())
            iou = overlap / (int(pd.sum()) + int(gt.sum()) - overlap)
            iou_scores.append(iou)
        except ZeroDivisionError:
            iou_scores.append(0.0)
    # Return computed IoU scores
    return iou_scores

def calc_Sensitivity(truth, pred, classes):
    sens_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            gt = np.equal(truth, i)
            pd = np.equal(pred, i)
            # Calculate sensitivity (int casts: see calc_DSC)
            sens = int(np.logical_and(pd, gt).sum()) / int(gt.sum())
            sens_scores.append(sens)
        except ZeroDivisionError:
            sens_scores.append(0.0)
    # Return computed sensitivity scores
    return sens_scores

def calc_Specificity(truth, pred, classes):
    spec_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            not_gt = np.logical_not(np.equal(truth, i))
            not_pd = np.logical_not(np.equal(pred, i))
            # Calculate specificity (int casts: see calc_DSC)
            spec = int(np.logical_and(not_pd, not_gt).sum()) / int(not_gt.sum())
            spec_scores.append(spec)
        except ZeroDivisionError:
            spec_scores.append(0.0)
    # Return computed specificity scores
    return spec_scores

def calc_Accuracy(truth, pred, classes):
    acc_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            gt = np.equal(truth, i)
            pd = np.equal(pred, i)
            not_gt = np.logical_not(np.equal(truth, i))
            not_pd = np.logical_not(np.equal(pred, i))
            # Calculate accuracy
            acc = (np.logical_and(pd, gt).sum() +
                   np.logical_and(not_pd, not_gt).sum()) / gt.size
            acc_scores.append(acc)
        except ZeroDivisionError:
            acc_scores.append(0.0)
    # Return computed accuracy scores
    return acc_scores

def calc_Precision(truth, pred, classes):
    prec_scores = []
    # Iterate over each class
    for i in range(classes):
        try:
            gt = np.equal(truth, i)
            pd = np.equal(pred, i)
            # Calculate precision (int casts: see calc_DSC)
            prec = int(np.logical_and(pd, gt).sum()) / int(pd.sum())
            prec_scores.append(prec)
        except ZeroDivisionError:
            prec_scores.append(0.0)
    # Return computed precision scores
    return prec_scores

#-----------------------------------------------------#
#                       Plotting                       #
#-----------------------------------------------------#
def plot_fitting(df_log):
    # Melt Data Set
    df_fitting = df_log.melt(id_vars=["epoch"],
                             value_vars=["loss", "val_loss"],
                             var_name="Dataset",
                             value_name="score")
    # Plot
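    # geom_smooth(method="gpr") fits a Gaussian process regression smoother
    # through the per-fold loss values; plotnine delegates this smoothing to
    # scikit-learn, which therefore needs to be installed.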
    fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)"))
               + geom_smooth(method="gpr", size=2)
               + ggtitle("Fitting Curve during Training")
               + xlab("Epoch")
               + ylab("Loss Function")
               + scale_y_continuous(limits=[0, 3])
               + scale_colour_discrete(name="Dataset",
                                       labels=["Training", "Validation"])
               + theme_bw(base_size=28))
    # # Alternative plot without smoothing
    # fig = (ggplot(df_fitting, aes("epoch", "score", color="factor(Dataset)"))
    #            + geom_line()
    #            + ggtitle("Fitting Curve during Training")
    #            + xlab("Epoch")
    #            + ylab("Loss Function")
    #            + theme_bw())
    fig.save(filename="fitting_curve.png", path="evaluation",
             width=12, height=10, dpi=300)

#-----------------------------------------------------#
#                    Run Evaluation                    #
#-----------------------------------------------------#
# Initialize Data IO Interface for NIfTI data
## We are using 4 classes: background, lung_left, lung_right, COVID-19 infection
interface = NIFTI_interface(channels=1, classes=4)

# Create Data IO object to load and write samples in the file structure
data_io = Data_IO(interface, input_path="data", output_path="predictions")
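# Assumption about the project setup: "data" follows MIScnn's standard NIfTI
# layout (one folder per case id holding the imaging and segmentation files),
# while "predictions" holds the predicted segmentations written by the pipeline.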

# Access all available samples in our file structure
sample_list = data_io.get_indiceslist()
sample_list.sort()

# Initialize dataframe
cols = ["index", "score", "background", "lung_L", "lung_R", "infection"]
df = pd.DataFrame(data=[], dtype=np.float64, columns=cols)
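# The dataframe is filled in long format below: one row per sample and metric,
# with the per-class scores in the remaining columns. Note that
# DataFrame.append() was removed in pandas 2.0, so this script targets an
# older pandas version (pd.concat would be the modern replacement).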

# Iterate over each sample
for index in tqdm(sample_list):
    # Load a sample including its image, ground truth and prediction
    sample = data_io.sample_loader(index, load_seg=True, load_pred=True)
    # Access image, ground truth and prediction data
    image = sample.img_data
    truth = sample.seg_data
    pred = sample.pred_data
    # Compute the various scores
    dsc = calc_DSC(truth, pred, classes=4)
    df = df.append(pd.Series([index, "DSC"] + dsc, index=cols),
                   ignore_index=True)
    iou = calc_IoU(truth, pred, classes=4)
    df = df.append(pd.Series([index, "IoU"] + iou, index=cols),
                   ignore_index=True)
    sens = calc_Sensitivity(truth, pred, classes=4)
    df = df.append(pd.Series([index, "Sens"] + sens, index=cols),
                   ignore_index=True)
    spec = calc_Specificity(truth, pred, classes=4)
    df = df.append(pd.Series([index, "Spec"] + spec, index=cols),
                   ignore_index=True)
    prec = calc_Precision(truth, pred, classes=4)
    df = df.append(pd.Series([index, "Prec"] + prec, index=cols),
                   ignore_index=True)
    acc = calc_Accuracy(truth, pred, classes=4)
    df = df.append(pd.Series([index, "Acc"] + acc, index=cols),
                   ignore_index=True)
    # Create the visualization (gif) for this sample
    visualize_evaluation(index, image, truth, pred, "evaluation/visualization")

# Output complete dataframe
print(df)
# Store complete dataframe to disk
path_res_detailed = os.path.join("evaluation", "cv_results.detailed.csv")
df.to_csv(path_res_detailed, index=False)

# Initialize fitting logging dataframe
cols = ["epoch", "dice_crossentropy", "dice_soft", "loss", "lr", "tversky_loss",
        "val_dice_crossentropy", "val_dice_soft", "val_loss", "val_tversky_loss",
        "fold"]
df_log = pd.DataFrame(data=[], dtype=np.float64, columns=cols)
cols_val = ["score", "background", "infection", "lungs", "fold"]
df_global = pd.DataFrame(data=[], dtype=np.float64, columns=cols_val)

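# Each cross-validation fold is expected as its own subdirectory, e.g.:
#   evaluation/fold_0/sample_list.csv   (validation samples of this fold)
#   evaluation/fold_0/history.tsv       (training history / fitting log)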
# Compute per-fold scores
for fold in os.listdir("evaluation"):
    # Skip all files in evaluation which are not cross-validation dirs
    if not fold.startswith("fold_"): continue
    # Identify validation samples of this fold
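    # (sample_list.csv is read as whitespace-separated text; its second row
    #  lists the validation samples, and the leading field is skipped)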
    path_detval = os.path.join("evaluation", fold, "sample_list.csv")
    detval = pd.read_csv(path_detval, sep=" ", header=None)
    detval = detval.iloc[1].dropna()
    val_list = detval.values.tolist()[1:]
    # Obtain metrics for validation list
    df_val = df.loc[df["index"].isin(val_list)]
    # Print out average and std evaluation metrics for the current fold
    df_avg = df_val.groupby(by="score").mean()
    df_std = df_val.groupby(by="score").std()
    print(fold)
    print(df_avg)
    print(df_std)
    # Compute median metrics for the validation list
    df_val = df_val.groupby(by="score").median()
    # Combine lung left and lung right class by mean
    df_val["lungs"] = df_val[["lung_L", "lung_R"]].mean(axis=1)
    df_val.drop(["lung_L", "lung_R"], axis=1, inplace=True)
    # Add df_val to the global results dataframe
    df_val["fold"] = fold
    df_val = df_val.reset_index()
    df_global = df_global.append(df_val, ignore_index=True)
    # Load logging data for fitting plot
    path_log = os.path.join("evaluation", fold, "history.tsv")
    df_logfold = pd.read_csv(path_log, header=0, sep="\t")
    df_logfold["fold"] = fold
    # Add logging data to global fitting log dataframe
    df_log = df_log.append(df_logfold, ignore_index=True)

# Run plotting of fitting process
plot_fitting(df_log)

# Output cross-validation results
print(df_global)
# Save cross-validation results to disk
path_res_global = os.path.join("evaluation", "cv_results.final.csv")
df_global.to_csv(path_res_global, index=False)