Diff of /run/utils.py [000000] .. [cc8b8f]


b/run/utils.py
import copy
import os
import torch

import numpy as np
import nibabel as nib
import SimpleITK as sitk

from semseg.data_loader import TorchIODataLoader3DTraining
from models.vnet3d import VNet3D
from semseg.utils import zero_pad_3d_image, z_score_normalization


def print_config(config):
    """Print every public attribute of the config object (long values are summarized)."""
    attributes_config = [attr for attr in dir(config)
                         if not attr.startswith('__')]
    print("Config")
    for item in attributes_config:
        attr_val = getattr(config, item)
        if len(str(attr_val)) < 100:
            print("{:15s} ==> {}".format(item, attr_val))
        else:
            print("{:15s} ==> String too long [{} characters]".format(item, len(str(attr_val))))


def check_train_set(config):
    """Sanity-check that the training images and labels are paired one-to-one."""
    num_train_images = len(config.train_images)
    num_train_labels = len(config.train_labels)

    assert num_train_images == num_train_labels, "Mismatch in number of training images and labels!"

    print("There are: {} Training Images".format(num_train_images))
    print("There are: {} Training Labels".format(num_train_labels))


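# Hedged usage sketch (not part of the original file): print_config and
# check_train_set only assume a config object exposing the right attributes,
# e.g. the hypothetical SimpleNamespace below.
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(train_images=["img0.nii.gz", "img1.nii.gz"],
#                            train_labels=["lab0.nii.gz", "lab1.nii.gz"])
#   print_config(config)
#   check_train_set(config)

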
def check_torch_loader(config, check_net=False):
    """Pull one batch from the TorchIO loader and report its shape.

    With check_net=True, a freshly initialized VNet3D is also run on the batch
    to verify that the network and the data shapes are compatible.
    """
    train_data_loader_3D = TorchIODataLoader3DTraining(config)
    iterable_data_loader = iter(train_data_loader_3D)
    el = next(iterable_data_loader)
    inputs, labels = el['t1']['data'], el['label']['data']
    print("Shape of Batch: [input {}] [label {}]".format(inputs.shape, labels.shape))
    if check_net:
        net = VNet3D(num_outs=config.num_outs, channels=config.num_channels)
        outputs = net(inputs)
        print("Shape of Output: [output {}]".format(outputs.shape))


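# Hedged usage sketch: with batch size B and patches of size Z x Y x X, the
# loader is expected to yield tensors shaped [B, 1, Z, Y, X]; check_net=True
# additionally runs one forward pass through an untrained VNet3D.
#
#   check_torch_loader(config, check_net=True)
#
# (Assumes config also defines num_outs, num_channels and whatever
# TorchIODataLoader3DTraining needs; those fields are defined outside this file.)

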
def print_folder(idx, train_index, val_index):
    """Print a banner for cross-validation fold `idx` and its train/val indices."""
    print("+==================+")
    print("+ Cross Validation +")
    print("+     Folder {:d}     +".format(idx))
    print("+==================+")
    print("TRAIN [Images: {:3d}]:\n{}".format(len(train_index), train_index))
    print("VAL   [Images: {:3d}]:\n{}".format(len(val_index), val_index))


def print_test():
    print("+============+")
    print("+   Test     +")
    print("+============+")


def train_val_split(train_images, train_labels, train_index, val_index):
    """Split image and label path lists into train/val sublists using index arrays
    (e.g. the indices produced by a cross-validation split)."""
    train_images_np, train_labels_np = np.array(train_images), np.array(train_labels)
    train_images_list = list(train_images_np[train_index])
    val_images_list = list(train_images_np[val_index])
    train_labels_list = list(train_labels_np[train_index])
    val_labels_list = list(train_labels_np[val_index])
    return train_images_list, val_images_list, train_labels_list, val_labels_list


def train_val_split_config(config, train_index, val_index):
    """Return a shallow copy of `config` with the train/val image and label lists
    replaced according to the given fold indices (see the sketch below)."""
    train_images_list, val_images_list, train_labels_list, val_labels_list = \
        train_val_split(config.train_images, config.train_labels, train_index, val_index)
    new_config = copy.copy(config)  # shallow copy: only the list attributes are rebound
    new_config.train_images, new_config.val_images = train_images_list, val_images_list
    new_config.train_labels, new_config.val_labels = train_labels_list, val_labels_list
    return new_config


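# Hedged cross-validation sketch (KFold and the 5-split choice are illustrative
# assumptions, not taken from this file):
#
#   from sklearn.model_selection import KFold
#   kf = KFold(n_splits=5, shuffle=True, random_state=42)
#   for idx, (train_index, val_index) in enumerate(kf.split(config.train_images)):
#       print_folder(idx, train_index, val_index)
#       fold_config = train_val_split_config(config, train_index, val_index)
#       # ... train on fold_config.train_images, validate on fold_config.val_images

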
def nii_load(train_image_path):
    """Load a NIfTI volume with nibabel; return the float32 array and its affine."""
    train_image_nii = nib.load(str(train_image_path), mmap=False)
    train_image_np = train_image_nii.get_fdata(dtype=np.float32)
    affine = train_image_nii.affine
    return train_image_np, affine


def sitk_load(train_image_path):
    """Load a volume with SimpleITK; return the array (Z x Y x X) and a metadata dict."""
    train_image_sitk = sitk.ReadImage(str(train_image_path))
    train_image_np = sitk.GetArrayFromImage(train_image_sitk)
    origin, spacing, direction = train_image_sitk.GetOrigin(), \
                                 train_image_sitk.GetSpacing(), train_image_sitk.GetDirection()
    meta_sitk = {
        'origin'   : origin,
        'spacing'  : spacing,
        'direction': direction
    }
    return train_image_np, meta_sitk


def nii_write(outputs_np, affine, filename_out):
    """Write an array to NIfTI with the given affine (qform enabled, sform disabled)."""
    outputs_nib = nib.Nifti1Image(outputs_np, affine)
    outputs_nib.header['qform_code'] = 1
    outputs_nib.header['sform_code'] = 0
    outputs_nib.to_filename(filename_out)


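# Hedged round-trip sketch: load a NIfTI volume, keep its affine, and write a
# (possibly modified) array back with the same geometry. Paths are placeholders.
#
#   volume_np, affine = nii_load("subject_t1.nii.gz")
#   nii_write(volume_np, affine, "subject_copy.nii.gz")

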
def sitk_write(outputs_np, meta_sitk, filename_out):
    """Write an array with SimpleITK, restoring origin/spacing/direction from `meta_sitk`."""
    outputs_sitk = sitk.GetImageFromArray(outputs_np)
    outputs_sitk.SetDirection(meta_sitk['direction'])
    outputs_sitk.SetSpacing(meta_sitk['spacing'])
    outputs_sitk.SetOrigin(meta_sitk['origin'])
    sitk.WriteImage(outputs_sitk, filename_out)


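# Hedged round-trip sketch for the SimpleITK variant; GetArrayFromImage returns
# the volume in Z x Y x X order, and the metadata dict restores origin, spacing
# and direction on write. Paths are placeholders.
#
#   volume_np, meta_sitk = sitk_load("subject_t1.nii.gz")
#   sitk_write(volume_np, meta_sitk, "subject_copy.nii.gz")

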
def np3d_to_torch5d(train_image_np, pad_ref, cuda_dev):
    """Normalize a 3D volume, pad it to `pad_ref`, and return a 1 x 1 x Z x Y x X tensor on `cuda_dev`."""
    train_image_np = z_score_normalization(train_image_np)

    inputs_padded = zero_pad_3d_image(train_image_np, pad_ref,
                                      value_to_pad=train_image_np.min())
    inputs_padded = np.expand_dims(inputs_padded, axis=0)  # 1 x Z x Y x X
    inputs_padded = np.expand_dims(inputs_padded, axis=0)  # 1 x 1 x Z x Y x X

    inputs = torch.from_numpy(inputs_padded).float()
    inputs = inputs.to(cuda_dev)
    return inputs


def torch5d_to_np3d(outputs, original_shape):
    """Convert 1 x C x Z x Y x X network logits to a uint8 label volume cropped to `original_shape`."""
    outputs = torch.argmax(outputs, dim=1)  # 1 x Z x Y x X
    outputs_np = outputs.detach().cpu().numpy()
    outputs_np = outputs_np[0]  # Z x Y x X
    outputs_np = outputs_np[:original_shape[0], :original_shape[1], :original_shape[2]]
    outputs_np = outputs_np.astype(np.uint8)
    return outputs_np


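# Hedged inference sketch tying the two converters together. The trained `net`,
# the device and the pad_ref size are assumptions; pad_ref is the Z x Y x X
# size the padded input is brought to.
#
#   cuda_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   image_np, affine = nii_load("subject_t1.nii.gz")
#   inputs = np3d_to_torch5d(image_np, pad_ref=(48, 64, 48), cuda_dev=cuda_dev)
#   with torch.no_grad():
#       outputs = net(inputs)  # 1 x C x Z x Y x X logits
#   mask_np = torch5d_to_np3d(outputs, image_np.shape)
#   nii_write(mask_np, affine, "subject_pred.nii.gz")

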
def print_metrics(multi_dices, f1_scores, train_confusion_matrix):
    """Print mean/std of the multi-class Dice, per-class Dice (columns 1 and 2 of
    `f1_scores`, i.e. hippocampus anterior and posterior), and the raw plus
    normalized confusion matrices."""
    multi_dices_np = np.array(multi_dices)
    mean_multi_dice = np.mean(multi_dices_np)
    std_multi_dice = np.std(multi_dices_np, ddof=1)

    f1_scores = np.array(f1_scores)

    f1_scores_anterior_mean = np.mean(f1_scores[:, 1])
    f1_scores_anterior_std = np.std(f1_scores[:, 1], ddof=1)

    f1_scores_posterior_mean = np.mean(f1_scores[:, 2])
    f1_scores_posterior_std = np.std(f1_scores[:, 2], ddof=1)

    print("+================================+")
    print("Multi Class Dice           ===> {:.4f} +/- {:.4f}".format(mean_multi_dice, std_multi_dice))
    print("Images with Dice > 0.8     ===> {} out of {}".format((multi_dices_np > 0.8).sum(), multi_dices_np.size))
    print("+================================+")
    print("Hippocampus Anterior Dice  ===> {:.4f} +/- {:.4f}".format(f1_scores_anterior_mean, f1_scores_anterior_std))
    print("Hippocampus Posterior Dice ===> {:.4f} +/- {:.4f}".format(f1_scores_posterior_mean, f1_scores_posterior_std))
    print("+================================+")
    print("Confusion Matrix")
    print(train_confusion_matrix)
    print("+================================+")
    print("Normalized (All) Confusion Matrix")
    train_confusion_matrix_normalized_all = train_confusion_matrix / train_confusion_matrix.sum()
    print(train_confusion_matrix_normalized_all)
    print("+================================+")
    print("Normalized (Row) Confusion Matrix")
    train_confusion_matrix_normalized_row = train_confusion_matrix.astype('float') / \
                                            train_confusion_matrix.sum(axis=1)[:, np.newaxis]
    print(train_confusion_matrix_normalized_row)
    print("+================================+")


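# Hedged usage sketch with toy values: multi_dices holds one multi-class Dice
# per image, f1_scores one row of per-class scores per image (column 0 assumed
# to be background, 1 anterior, 2 posterior), and the confusion matrix would
# come from sklearn.metrics.confusion_matrix on the flattened voxel labels.
#
#   multi_dices = [0.85, 0.82, 0.79]
#   f1_scores = [[0.99, 0.86, 0.84], [0.99, 0.83, 0.81], [0.99, 0.80, 0.78]]
#   cm = np.array([[9500, 30, 20], [40, 250, 10], [35, 15, 100]])
#   print_metrics(multi_dices, f1_scores, cm)

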
def plot_confusion_matrix(cm,
                          target_names=None,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True,
                          already_normalized=False,
                          path_out=None):
    """
    Given a confusion matrix from sklearn.metrics.confusion_matrix (cm), make a nice plot.

    Arguments
    ---------
    cm:                  confusion matrix from sklearn.metrics.confusion_matrix

    target_names:        given classification classes such as [0, 1, 2],
                         the class names, for example: ['high', 'medium', 'low']

    title:               the text to display at the top of the matrix

    cmap:                the gradient of the values displayed from matplotlib.pyplot.cm
                         see http://matplotlib.org/examples/color/colormaps_reference.html
                         plt.get_cmap('jet') or plt.cm.Blues

    normalize:           if False, plot the raw numbers
                         if True, normalize each row and plot the proportions

    already_normalized:  set to True if `cm` is already normalized, so that cell
                         values are formatted as proportions without re-normalizing

    path_out:            if given, the figure is also saved to this path

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    import matplotlib.pyplot as plt
    import itertools

    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(8, 8))
    plt.matshow(cm, cmap=cmap, fignum=0)  # fignum=0: draw into the figure created above instead of opening a new one
    plt.title(title, pad=25.)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # cells with values above this threshold get white text (dark background), the rest black
    thresh = cm.max() / 1.5 if normalize or already_normalized else cm.max() / 2
    print("Thresh = {}".format(thresh))
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize or already_normalized:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    if path_out is not None:
        plt.savefig(path_out)
    plt.show()
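

# Hedged usage sketch (class names are illustrative placeholders consistent with
# the labels used in print_metrics above):
#
#   from sklearn.metrics import confusion_matrix
#   cm = confusion_matrix(y_true.flatten(), y_pred.flatten())
#   plot_confusion_matrix(cm,
#                         target_names=["Background", "Anterior", "Posterior"],
#                         normalize=True,
#                         title="Hippocampus Confusion Matrix",
#                         path_out="confusion_matrix.png")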