|
a |
|
b/run/utils.py |
|
|
1 |
import copy |
|
|
2 |
import os |
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import nibabel as nib |
|
|
7 |
import SimpleITK as sitk |
|
|
8 |
|
|
|
9 |
from semseg.data_loader import TorchIODataLoader3DTraining |
|
|
10 |
from models.vnet3d import VNet3D |
|
|
11 |
from semseg.utils import zero_pad_3d_image, z_score_normalization |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def print_config(config): |
|
|
15 |
attributes_config = [attr for attr in dir(config) |
|
|
16 |
if not attr.startswith('__')] |
|
|
17 |
print("Config") |
|
|
18 |
for item in attributes_config: |
|
|
19 |
attr_val = getattr(config,item) |
|
|
20 |
if len(str(attr_val)) < 100: |
|
|
21 |
print("{:15s} ==> {}".format(item, attr_val)) |
|
|
22 |
else: |
|
|
23 |
print("{:15s} ==> String too long [{} characters]".format(item,len(str(attr_val)))) |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def check_train_set(config): |
|
|
27 |
num_train_images = len(config.train_images) |
|
|
28 |
num_train_labels = len(config.train_labels) |
|
|
29 |
|
|
|
30 |
assert num_train_images == num_train_labels, "Mismatch in number of training images and labels!" |
|
|
31 |
|
|
|
32 |
print("There are: {} Training Images".format(num_train_images)) |
|
|
33 |
print("There are: {} Training Labels".format(num_train_labels)) |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def check_torch_loader(config, check_net=False): |
|
|
37 |
train_data_loader_3D = TorchIODataLoader3DTraining(config) |
|
|
38 |
iterable_data_loader = iter(train_data_loader_3D) |
|
|
39 |
el = next(iterable_data_loader) |
|
|
40 |
inputs, labels = el['t1']['data'], el['label']['data'] |
|
|
41 |
print("Shape of Batch: [input {}] [label {}]".format(inputs.shape, labels.shape)) |
|
|
42 |
if check_net: |
|
|
43 |
net = VNet3D(num_outs=config.num_outs, channels=config.num_channels) |
|
|
44 |
outputs = net(inputs) |
|
|
45 |
print("Shape of Output: [output {}]".format(outputs.shape)) |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
def print_folder(idx, train_index, val_index): |
|
|
49 |
print("+==================+") |
|
|
50 |
print("+ Cross Validation +") |
|
|
51 |
print("+ Folder {:d} +".format(idx)) |
|
|
52 |
print("+==================+") |
|
|
53 |
print("TRAIN [Images: {:3d}]:\n{}".format(len(train_index), train_index)) |
|
|
54 |
print("VAL [Images: {:3d}]:\n{}".format(len(val_index), val_index)) |
|
|
55 |
|
|
|
56 |
|
|
|
57 |
def print_test(): |
|
|
58 |
print("+============+") |
|
|
59 |
print("+ Test +") |
|
|
60 |
print("+============+") |
|
|
61 |
|
|
|
62 |
|
|
|
63 |
def train_val_split(train_images, train_labels, train_index, val_index): |
|
|
64 |
train_images_np, train_labels_np = np.array(train_images), np.array(train_labels) |
|
|
65 |
train_images_list = list(train_images_np[train_index]) |
|
|
66 |
val_images_list = list(train_images_np[val_index]) |
|
|
67 |
train_labels_list = list(train_labels_np[train_index]) |
|
|
68 |
val_labels_list = list(train_labels_np[val_index]) |
|
|
69 |
return train_images_list, val_images_list, train_labels_list, val_labels_list |
|
|
70 |
|
|
|
71 |
|
|
|
72 |
def train_val_split_config(config, train_index, val_index): |
|
|
73 |
train_images_list, val_images_list, train_labels_list, val_labels_list = \ |
|
|
74 |
train_val_split(config.train_images, config.train_labels, train_index, val_index) |
|
|
75 |
new_config = copy.copy(config) |
|
|
76 |
new_config.train_images, new_config.val_images = train_images_list, val_images_list |
|
|
77 |
new_config.train_labels, new_config.val_labels = train_labels_list, val_labels_list |
|
|
78 |
return new_config |
|
|
79 |
|
|
|
80 |
|
|
|
81 |
def nii_load(train_image_path): |
|
|
82 |
train_image_nii = nib.load(str(train_image_path), mmap=False) |
|
|
83 |
train_image_np = train_image_nii.get_fdata(dtype=np.float32) |
|
|
84 |
affine = train_image_nii.affine |
|
|
85 |
return train_image_np, affine |
|
|
86 |
|
|
|
87 |
|
|
|
88 |
def sitk_load(train_image_path): |
|
|
89 |
train_image_sitk = sitk.ReadImage(train_image_path) |
|
|
90 |
train_image_np = sitk.GetArrayFromImage(train_image_sitk) |
|
|
91 |
origin, spacing, direction = train_image_sitk.GetOrigin(), \ |
|
|
92 |
train_image_sitk.GetSpacing(), train_image_sitk.GetDirection() |
|
|
93 |
meta_sitk = { |
|
|
94 |
'origin' : origin, |
|
|
95 |
'spacing' : spacing, |
|
|
96 |
'direction': direction |
|
|
97 |
} |
|
|
98 |
return train_image_np, meta_sitk |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
def nii_write(outputs_np, affine, filename_out): |
|
|
102 |
outputs_nib = nib.Nifti1Image(outputs_np, affine) |
|
|
103 |
outputs_nib.header['qform_code'] = 1 |
|
|
104 |
outputs_nib.header['sform_code'] = 0 |
|
|
105 |
outputs_nib.to_filename(filename_out) |
|
|
106 |
|
|
|
107 |
|
|
|
108 |
def sitk_write(outputs_np, meta_sitk, filename_out): |
|
|
109 |
outputs_sitk = sitk.GetImageFromArray(outputs_np) |
|
|
110 |
outputs_sitk.SetDirection(meta_sitk['direction']) |
|
|
111 |
outputs_sitk.SetSpacing(meta_sitk['spacing']) |
|
|
112 |
outputs_sitk.SetOrigin(meta_sitk['origin']) |
|
|
113 |
sitk.WriteImage(outputs_sitk, filename_out) |
|
|
114 |
|
|
|
115 |
|
|
|
116 |
def np3d_to_torch5d(train_image_np, pad_ref, cuda_dev): |
|
|
117 |
train_image_np = z_score_normalization(train_image_np) |
|
|
118 |
|
|
|
119 |
inputs_padded = zero_pad_3d_image(train_image_np, pad_ref, |
|
|
120 |
value_to_pad=train_image_np.min()) |
|
|
121 |
inputs_padded = np.expand_dims(inputs_padded, axis=0) # 1 x Z x Y x X |
|
|
122 |
inputs_padded = np.expand_dims(inputs_padded, axis=0) # 1 x 1 x Z x Y x X |
|
|
123 |
|
|
|
124 |
inputs = torch.from_numpy(inputs_padded).float() |
|
|
125 |
inputs = inputs.to(cuda_dev) |
|
|
126 |
return inputs |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def torch5d_to_np3d(outputs, original_shape): |
|
|
130 |
outputs = torch.argmax(outputs, dim=1) # 1 x Z x Y x X |
|
|
131 |
outputs_np = outputs.data.cpu().numpy() |
|
|
132 |
outputs_np = outputs_np[0] # Z x Y x X |
|
|
133 |
outputs_np = outputs_np[:original_shape[0],:original_shape[1],:original_shape[2]] |
|
|
134 |
outputs_np = outputs_np.astype(np.uint8) |
|
|
135 |
return outputs_np |
|
|
136 |
|
|
|
137 |
|
|
|
138 |
def print_metrics(multi_dices, f1_scores, train_confusion_matrix): |
|
|
139 |
multi_dices_np = np.array(multi_dices) |
|
|
140 |
mean_multi_dice = np.mean(multi_dices_np) |
|
|
141 |
std_multi_dice = np.std(multi_dices_np, ddof=1) |
|
|
142 |
|
|
|
143 |
f1_scores = np.array(f1_scores) |
|
|
144 |
|
|
|
145 |
f1_scores_anterior_mean = np.mean(f1_scores[:, 1]) |
|
|
146 |
f1_scores_anterior_std = np.std(f1_scores[:, 1], ddof=1) |
|
|
147 |
|
|
|
148 |
f1_scores_posterior_mean = np.mean(f1_scores[:, 2]) |
|
|
149 |
f1_scores_posterior_std = np.std(f1_scores[:, 2], ddof=1) |
|
|
150 |
|
|
|
151 |
print("+================================+") |
|
|
152 |
print("Multi Class Dice ===> {:.4f} +/- {:.4f}".format(mean_multi_dice, std_multi_dice)) |
|
|
153 |
print("Images with Dice > 0.8 ===> {} on {}".format((multi_dices_np > 0.8).sum(), multi_dices_np.size)) |
|
|
154 |
print("+================================+") |
|
|
155 |
print("Hippocampus Anterior Dice ===> {:.4f} +/- {:.4f}".format(f1_scores_anterior_mean, f1_scores_anterior_std)) |
|
|
156 |
print("Hippocampus Posterior Dice ===> {:.4f} +/- {:.4f}".format(f1_scores_posterior_mean, f1_scores_posterior_std)) |
|
|
157 |
print("+================================+") |
|
|
158 |
print("Confusion Matrix") |
|
|
159 |
print(train_confusion_matrix) |
|
|
160 |
print("+================================+") |
|
|
161 |
print("Normalized (All) Confusion Matrix") |
|
|
162 |
train_confusion_matrix_normalized_all = train_confusion_matrix / train_confusion_matrix.sum() |
|
|
163 |
print(train_confusion_matrix_normalized_all) |
|
|
164 |
print("+================================+") |
|
|
165 |
print("Normalized (Row) Confusion Matrix") |
|
|
166 |
train_confusion_matrix_normalized_row = train_confusion_matrix.astype('float') / \ |
|
|
167 |
train_confusion_matrix.sum(axis=1)[:, np.newaxis] |
|
|
168 |
print(train_confusion_matrix_normalized_row) |
|
|
169 |
print("+================================+") |
|
|
170 |
|
|
|
171 |
|
|
|
172 |
def plot_confusion_matrix(cm, |
|
|
173 |
target_names=None, |
|
|
174 |
title='Confusion matrix', |
|
|
175 |
cmap=None, |
|
|
176 |
normalize=True, |
|
|
177 |
already_normalized=False, |
|
|
178 |
path_out=None): |
|
|
179 |
""" |
|
|
180 |
given a sklearn confusion matrix (cm), make a nice plot |
|
|
181 |
|
|
|
182 |
Arguments |
|
|
183 |
--------- |
|
|
184 |
cm: confusion matrix from sklearn.metrics.confusion_matrix |
|
|
185 |
|
|
|
186 |
target_names: given classification classes such as [0, 1, 2] |
|
|
187 |
the class names, for example: ['high', 'medium', 'low'] |
|
|
188 |
|
|
|
189 |
title: the text to display at the top of the matrix |
|
|
190 |
|
|
|
191 |
cmap: the gradient of the values displayed from matplotlib.pyplot.cm |
|
|
192 |
see http://matplotlib.org/examples/color/colormaps_reference.html |
|
|
193 |
plt.get_cmap('jet') or plt.cm.Blues |
|
|
194 |
|
|
|
195 |
normalize: If False, plot the raw numbers |
|
|
196 |
If True, plot the proportions |
|
|
197 |
|
|
|
198 |
Usage |
|
|
199 |
----- |
|
|
200 |
plot_confusion_matrix(cm = cm, # confusion matrix created by |
|
|
201 |
# sklearn.metrics.confusion_matrix |
|
|
202 |
normalize = True, # show proportions |
|
|
203 |
target_names = y_labels_vals, # list of names of the classes |
|
|
204 |
title = best_estimator_name) # title of graph |
|
|
205 |
|
|
|
206 |
Citiation |
|
|
207 |
--------- |
|
|
208 |
http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html |
|
|
209 |
|
|
|
210 |
""" |
|
|
211 |
import matplotlib.pyplot as plt |
|
|
212 |
import numpy as np |
|
|
213 |
import itertools |
|
|
214 |
|
|
|
215 |
accuracy = np.trace(cm) / np.sum(cm).astype('float') |
|
|
216 |
misclass = 1 - accuracy |
|
|
217 |
|
|
|
218 |
if cmap is None: |
|
|
219 |
cmap = plt.get_cmap('Blues') |
|
|
220 |
|
|
|
221 |
if normalize: |
|
|
222 |
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] |
|
|
223 |
|
|
|
224 |
plt.figure(figsize=(8, 8)) |
|
|
225 |
plt.matshow(cm, cmap=cmap) |
|
|
226 |
plt.title(title, pad=25.) |
|
|
227 |
plt.colorbar() |
|
|
228 |
|
|
|
229 |
if target_names is not None: |
|
|
230 |
tick_marks = np.arange(len(target_names)) |
|
|
231 |
plt.xticks(tick_marks, target_names, rotation=45) |
|
|
232 |
plt.yticks(tick_marks, target_names) |
|
|
233 |
|
|
|
234 |
thresh = cm.max() / 1.5 if normalize or already_normalized else cm.max() / 2 |
|
|
235 |
print("Thresh = {}".format(thresh)) |
|
|
236 |
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): |
|
|
237 |
if normalize or already_normalized: |
|
|
238 |
plt.text(j, i, "{:0.4f}".format(cm[i, j]), |
|
|
239 |
horizontalalignment="center", |
|
|
240 |
color="white" if cm[i, j] > thresh else "black") |
|
|
241 |
else: |
|
|
242 |
plt.text(j, i, "{:,}".format(cm[i, j]), |
|
|
243 |
horizontalalignment="center", |
|
|
244 |
color="white" if cm[i, j] > thresh else "black") |
|
|
245 |
|
|
|
246 |
plt.ylabel('True label') |
|
|
247 |
plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass)) |
|
|
248 |
if path_out is not None: |
|
|
249 |
plt.savefig(path_out) |
|
|
250 |
plt.show() |