Diff of /run/validate.py [000000] .. [cc8b8f]

##########################
# Nicola Altini (2020)
# V-Net for Hippocampus Segmentation from MRI with PyTorch
##########################
# python run/validate.py
# python run/validate.py --dir=logs/no_augm_torchio
# python run/validate.py --dir=logs/no_augm_torchio --write=0
# python run/validate.py --dir=path/to/logs/dir --write=WRITE --verbose=VERBOSE

##########################
# Imports
##########################
import torch
import numpy as np
import os
from sklearn.model_selection import KFold
import argparse
import sys

##########################
# Local Imports
##########################
current_path_abs = os.path.abspath('.')
sys.path.append(current_path_abs)
print('{} appended to sys.path!'.format(current_path_abs))

from config.paths import ( train_images_folder, train_labels_folder, train_prediction_folder,
                           train_images, train_labels,
                           test_images_folder, test_images, test_prediction_folder)
from run.utils import (train_val_split, print_folder, nii_load, sitk_load, nii_write, print_config,
                       sitk_write, print_test, np3d_to_torch5d, torch5d_to_np3d, print_metrics, plot_confusion_matrix)
from config.config import SemSegMRIConfig
from semseg.utils import multi_dice_coeff
from sklearn.metrics import confusion_matrix, f1_score


def run(logs_dir="logs", write_out=False, plot_conf=False):
    ##########################
    # Config
    ##########################
    config = SemSegMRIConfig()
    print_config(config)

    ###########################
    # Load Net
    ###########################
    cuda_dev = torch.device("cuda")

    # Load From State Dict
    # path_net = "logs/model_epoch_0080.pht"
    # net = VNet3D(num_outs=config.num_outs, channels=config.num_channels)
    # net.load_state_dict(torch.load(path_net))

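    # Assumption: "model.pt" is the single model used below for the test images, while
    # "model_folder_{k}.pt" are the per-fold models produced by cross-validation training.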
    path_net = os.path.join(logs_dir,"model.pt")
    path_nets_crossval = [os.path.join(logs_dir,"model_folder_{:d}.pt".format(idx))
                          for idx in range(config.num_folders)]

    ###########################
    # Eval Loop
    ###########################
    use_nib = True
    pad_ref = (48,64,48)
    multi_dices = list()
    f1_scores = list()

    os.makedirs(train_prediction_folder, exist_ok=True)
    os.makedirs(test_prediction_folder, exist_ok=True)

    train_and_test = [True, False]
    train_and_test_images = [train_images, test_images]
    train_and_test_images_folder = [train_images_folder, test_images_folder]
    train_and_test_prediction_folder = [train_prediction_folder, test_prediction_folder]

    train_confusion_matrix = np.zeros((config.num_outs, config.num_outs))

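    # Two evaluation passes: first the training images (k-fold cross-validation, one model per fold,
    # metrics computed against the ground-truth labels), then the test images (single model,
    # predictions only; no metrics are computed).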
    for train_or_test_images, train_or_test_images_folder, train_or_test_prediction_folder, is_training in \
            zip(train_and_test_images, train_and_test_images_folder, train_and_test_prediction_folder, train_and_test):
        print("Images Folder: {}".format(train_or_test_images_folder))
        print("IsTraining: {}".format(is_training))

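        # KFold with the default shuffle=False is deterministic, so these splits should match the
        # folds used at training time (assuming training built its folds the same way).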
        kf = KFold(n_splits=config.num_folders)
        for idx_crossval, (train_index, val_index) in enumerate(kf.split(train_images)):
            if is_training:
                print_folder(idx_crossval, train_index, val_index)
                model_path = path_nets_crossval[idx_crossval]
                print("Model: {}".format(model_path))
                net = torch.load(model_path)
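                # Evaluate this fold's model on its validation split: train_or_test_images and
                # train_labels_crossval are replaced by the images/labels selected by val_index
                # (assuming train_val_split returns (train_imgs, val_imgs, train_lbls, val_lbls)).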
                _, train_or_test_images, _, train_labels_crossval = \
                    train_val_split(train_images, train_labels, train_index, val_index)
            else:
                print_test()
                net = torch.load(path_net)
            net = net.cuda(cuda_dev)
            net.eval()

            for idx, train_image in enumerate(train_or_test_images):
                print("Iter {} of {}".format(idx,len(train_or_test_images)))
                print("Image: {}".format(train_image))
                train_image_path = os.path.join(train_or_test_images_folder, train_image)

                if use_nib:
                    train_image_np, affine = nii_load(train_image_path)
                else:
                    train_image_np, meta_sitk = sitk_load(train_image_path)

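                # Inference: np3d_to_torch5d converts the 3D volume to a 5D tensor on the GPU
                # (pad_ref is the reference spatial size), and torch5d_to_np3d converts the
                # prediction back to a 3D volume with the original image shape.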
                with torch.no_grad():
                    inputs = np3d_to_torch5d(train_image_np, pad_ref, cuda_dev)
                    outputs = net(inputs)
                    outputs_np = torch5d_to_np3d(outputs, train_image_np.shape)

                if write_out:
                    filename_out = os.path.join(train_or_test_prediction_folder, train_image)
                    if use_nib:
                        nii_write(outputs_np, affine, filename_out)
                    else:
                        sitk_write(outputs_np, meta_sitk, filename_out)

                if is_training:
                    train_label = train_labels_crossval[idx]
                    train_label_path = os.path.join(train_labels_folder, train_label)
                    if use_nib:
                        train_label_np, _ = nii_load(train_label_path)
                    else:
                        train_label_np, _ = sitk_load(train_label_path)

                    multi_dice = multi_dice_coeff(np.expand_dims(train_label_np,axis=0),
                                                  np.expand_dims(outputs_np,axis=0),
                                                  config.num_outs)
                    print("Multi Class Dice Coeff = {:.4f}".format(multi_dice))
                    multi_dices.append(multi_dice)

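                    # Per-class F1 and confusion matrix on the flattened volumes; the confusion
                    # matrix is accumulated over all cross-validation images.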
                    f1_score_idx = f1_score(train_label_np.flatten(), outputs_np.flatten(), average=None)
                    # labels= keeps the matrix shape fixed even if a class is absent from this volume
                    cm_idx = confusion_matrix(train_label_np.flatten(), outputs_np.flatten(),
                                              labels=list(range(config.num_outs)))
                    train_confusion_matrix += cm_idx
                    f1_scores.append(f1_score_idx)

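            # Only the training data is evaluated per fold; the test set uses a single model,
            # so one iteration of the fold loop is enough.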
            if not is_training:
                break

    print_metrics(multi_dices, f1_scores, train_confusion_matrix)

    if plot_conf:
        plot_confusion_matrix(train_confusion_matrix,
                              target_names=None, title='Cross-Validation Confusion matrix',
                              cmap=None, normalize=False, already_normalized=False,
                              path_out="images/conf_matrix_no_norm_no_augm_torchio.png")
        plot_confusion_matrix(train_confusion_matrix,
                              target_names=None, title='Cross-Validation Confusion matrix (row-normalized)',
                              cmap=None, normalize=True, already_normalized=False,
                              path_out="images/conf_matrix_normalized_row_no_augm_torchio.png")


############################
# MAIN
############################
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Validation for Hippocampus Segmentation")
    # Note: argparse's type=bool treats any non-empty string (including "0") as True,
    # so the boolean-like flags are parsed as integers (0/1) instead.
    parser.add_argument(
        "-V",
        "--verbose",
        default=0, type=int,
        help="Set to 1 for VERBOSE mode (also plots and saves the confusion matrices); 0 otherwise."
    )
    parser.add_argument(
        "-D",
        "--dir",
        default="logs", type=str,
        help="Local path to logs dir"
    )
    parser.add_argument(
        "-W",
        "--write",
        default=0, type=int,
        help="Set to 1 for WRITE mode (save predictions to disk); 0 otherwise."
    )
    parser.add_argument(
        "--net",
        default='vnet',
        help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **"
    )

    args = parser.parse_args()
    run(logs_dir=args.dir, write_out=bool(args.write), plot_conf=bool(args.verbose))