Diff of /run/validate_torchio.py [000000] .. [cc8b8f]

Switch to unified view

a b/run/validate_torchio.py
1
##########################
2
# Nicola Altini (2020)
3
# V-Net for Hippocampus Segmentation from MRI with PyTorch
4
##########################
5
# python run/validate_torchio.py
6
# python run/validate_torchio.py --dir=logs/no_augm_torchio
7
# python run/validate_torchio.py --dir=path/to/logs/dir --verbose=VERBOSE
8
9
##########################
10
# Imports
11
##########################
12
import os
13
import sys
14
import argparse
15
import numpy as np
16
import torch
17
from sklearn.model_selection import KFold
18
19
##########################
20
# Local Imports
21
##########################
22
current_path_abs = os.path.abspath('.')
23
sys.path.append(current_path_abs)
24
print('{} appended to sys!'.format(current_path_abs))
25
26
from run.utils import (train_val_split_config, print_folder, print_config, check_train_set)
27
from config.config import *
28
from config.paths import logs_folder, train_images, train_labels
29
from semseg.train import val_model
30
from semseg.data_loader import TorchIODataLoader3DValidation
31
32
33
def run(logs_dir="logs"):
34
    config = SemSegMRIConfig()
35
36
    ##########################
37
    # Check training set
38
    ##########################
39
    check_train_set(config)
40
41
    ##########################
42
    # Config
43
    ##########################
44
    config.batch_size = 1
45
    print_config(config)
46
47
    path_nets_crossval = [os.path.join(logs_dir,"model_folder_{:d}.pt".format(idx))
48
                          for idx in range(config.num_folders)]
49
50
    ##########################
51
    # Val loop
52
    ##########################
53
    cuda_dev = torch.device('cuda')
54
55
    if config.do_crossval:
56
        ##########################
57
        # cross-validation
58
        ##########################
59
        multi_dices_crossval = list()
60
        mean_multi_dice_crossval = list()
61
        std_multi_dice_crossval = list()
62
63
        kf = KFold(n_splits=config.num_folders)
64
        for idx, (train_index, val_index) in enumerate(kf.split(train_images)):
65
            print_folder(idx, train_index, val_index)
66
            config_crossval = train_val_split_config(config, train_index, val_index)
67
68
            ##########################
69
            # Training (cross-validation)
70
            ##########################
71
            model_path = path_nets_crossval[idx]
72
            print("Model: {}".format(model_path))
73
            net = torch.load(model_path)
74
75
            ##########################
76
            # Validation (cross-validation)
77
            ##########################
78
            val_data_loader_3D = TorchIODataLoader3DValidation(config_crossval)
79
            multi_dices, mean_multi_dice, std_multi_dice = val_model(net, val_data_loader_3D,
80
                                                                     config_crossval, device=cuda_dev)
81
            multi_dices_crossval.append(multi_dices)
82
            mean_multi_dice_crossval.append(mean_multi_dice)
83
            std_multi_dice_crossval.append(std_multi_dice)
84
            torch.save(net, os.path.join(logs_folder, "model_folder_{:d}.pt".format(idx)))
85
86
        ##########################
87
        # Saving Validation Results
88
        ##########################
89
        multi_dices_crossval_flatten = [item for sublist in multi_dices_crossval for item in sublist]
90
        mean_multi_dice_crossval_flatten = np.mean(multi_dices_crossval_flatten)
91
        std_multi_dice_crossval_flatten = np.std(multi_dices_crossval_flatten)
92
        print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice_crossval_flatten, std_multi_dice_crossval_flatten))
93
        # Multi-Dice: 0.8668 +/- 0.0337
94
95
96
############################
97
# MAIN
98
############################
99
if __name__ == "__main__":
100
    parser = argparse.ArgumentParser(description="Run Validation (With torchio based Data Loader) "
101
                                                 "for Hippocampus Segmentation")
102
    parser.add_argument(
103
        "-V",
104
        "--verbose",
105
        default=False, type=bool,
106
        help="Boolean flag. Set to true for VERBOSE mode; false otherwise."
107
    )
108
    parser.add_argument(
109
        "-D",
110
        "--dir",
111
        default="logs", type=str,
112
        help="Local path to logs dir"
113
    )
114
    parser.add_argument(
115
        "--net",
116
        default='vnet',
117
        help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **"
118
    )
119
120
    args = parser.parse_args()
121
    run(logs_dir=args.dir)