Diff of /run/train.py [000000] .. [cc8b8f]

Switch to unified view

a b/run/train.py
1
##########################
2
# Nicola Altini (2020)
3
# V-Net for Hippocampus Segmentation from MRI with PyTorch
4
##########################
5
# python run/train.py
6
# python run/train.py --epochs=NUM_EPOCHS --batch=BATCH_SIZE --workers=NUM_WORKERS --lr=LR
7
# python run/train.py --epochs=5 --batch=1 --net=unet
8
9
##########################
10
# Imports
11
##########################
12
import argparse
13
import os
14
import sys
15
import numpy as np
16
import torch
17
import torch.optim as optim
18
from sklearn.model_selection import KFold
19
20
##########################
21
# Local Imports
22
##########################
23
# Append the absolute path of the current working directory to the module
# search path (presumably so the project packages imported further down
# resolve when the script is launched from the repository root -- confirm).
current_path_abs = os.path.abspath(os.curdir)
sys.path.append(current_path_abs)
print(f'{current_path_abs} appended to sys!')
26
27
from run.utils import print_config, check_train_set, check_torch_loader, print_folder, train_val_split_config
28
from config.config import SemSegMRIConfig
29
from config.paths import logs_folder
30
from semseg.train import train_model, val_model
31
from semseg.data_loader import TorchIODataLoader3DTraining, TorchIODataLoader3DValidation
32
from models.vnet3d import VNet3D
33
from models.unet3d import UNet3D
34
35
36
def get_net(config):
    """Instantiate the segmentation network selected by ``config.net``.

    Args:
        config: configuration object. ``config.net`` must be ``'vnet'`` or
            ``'unet'``; ``config.num_outs`` is forwarded to both networks and
            ``config.num_channels`` is forwarded to ``VNet3D`` only.

    Returns:
        A ``VNet3D`` built with ``num_outs`` / ``channels`` taken from the
        config, or a ``UNet3D`` built with ``num_out_classes`` taken from the
        config, ``input_channels=1`` and ``init_feat_channels=32``.

    Raises:
        ValueError: if ``config.net`` is not one of the supported names.
    """
    name = config.net
    if name == 'vnet':
        return VNet3D(num_outs=config.num_outs, channels=config.num_channels)
    if name == 'unet':
        return UNet3D(num_out_classes=config.num_outs, input_channels=1, init_feat_channels=32)
    # Raised explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let an unknown name fall through and
    # return None.
    raise ValueError("Network Name not valid or not supported! Use one of ['unet', 'vnet']")
43
44
45
def _fit_network(config, device):
    """Build the network selected by ``config``, train it and return the result of ``train_model``."""
    net = get_net(config)
    # The learning rate comes from the configuration (and therefore from the
    # ``--lr`` command line option) instead of a hard-coded constant.
    optimizer = optim.Adam(net.parameters(), lr=config.lr)
    train_data_loader_3D = TorchIODataLoader3DTraining(config)
    return train_model(net, optimizer, train_data_loader_3D,
                       config, device=device, logs_folder=logs_folder)


def run(config):
    """Run the sanity checks, the optional cross-validation and the final training.

    Steps, in order:
      1. ``check_train_set``, ``print_config`` and ``check_torch_loader``
         (with ``check_net=False``) are called on ``config``.
      2. If ``config.do_crossval`` is true, the items of
         ``config.train_images`` are split with ``KFold`` into
         ``config.num_folders`` folds. For every fold a fresh network is
         trained, validated with ``val_model`` and saved to
         ``model_folder_<idx>.pt`` inside ``logs_folder``; the mean and
         standard deviation of the pooled ``multi_dices`` values are printed.
      3. A fresh network is trained with the unmodified ``config`` and saved
         to ``model.pt`` inside ``logs_folder``.

    Args:
        config: configuration object. This function reads ``do_crossval``,
            ``num_folders``, ``train_images`` and ``lr`` from it; the helper
            functions it calls receive the whole object.

    Returns:
        None. Results are printed and the trained models are written to disk.
    """
    ##########################
    # Check training set
    ##########################
    check_train_set(config)

    ##########################
    # Config
    ##########################
    print_config(config)

    ##########################
    # Check Torch DataLoader and Net
    ##########################
    check_torch_loader(config, check_net=False)

    ##########################
    # Training loop
    ##########################
    # Training and validation are pinned to the CUDA device.
    cuda_dev = torch.device('cuda')

    if config.do_crossval:
        ##########################
        # Training (cross-validation)
        ##########################
        # Values returned by val_model as ``multi_dices``, pooled over all folds.
        multi_dices_crossval = []

        kf = KFold(n_splits=config.num_folders)
        for idx, (train_index, val_index) in enumerate(kf.split(config.train_images)):
            print_folder(idx, train_index, val_index)
            config_crossval = train_val_split_config(config, train_index, val_index)
            # Propagate the configured learning rate to the fold configuration
            # explicitly, so every fold trains with the rate the user asked for.
            config_crossval.lr = config.lr
            net = _fit_network(config_crossval, cuda_dev)

            ##########################
            # Validation (cross-validation)
            ##########################
            val_data_loader_3D = TorchIODataLoader3DValidation(config_crossval)
            # Only the first returned value is aggregated; the per-fold mean
            # and std returned alongside it are not used here.
            multi_dices, _, _ = val_model(net, val_data_loader_3D,
                                          config_crossval, device=cuda_dev)
            multi_dices_crossval.extend(multi_dices)
            torch.save(net, os.path.join(logs_folder, "model_folder_{:d}.pt".format(idx)))

        ##########################
        # Saving Validation Results
        ##########################
        mean_multi_dice_crossval = np.mean(multi_dices_crossval)
        std_multi_dice_crossval = np.std(multi_dices_crossval)
        print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice_crossval, std_multi_dice_crossval))
        # Value recorded in the original source: Multi-Dice: 0.8728 +/- 0.0227

    ##########################
    # Training (full training set)
    ##########################
    net = _fit_network(config, cuda_dev)

    torch.save(net, os.path.join(logs_folder, "model.pt"))
120
121
122
############################
123
# MAIN
124
############################
125
def parse_cli_args(config, argv=None):
    """Parse the training command line options.

    Args:
        config: configuration object supplying the defaults; its ``epochs``,
            ``batch_size``, ``val_epochs``, ``num_workers`` and ``lr``
            attributes are read.
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with the attributes ``epochs``, ``batch``,
        ``val_epochs``, ``workers``, ``net`` and ``lr``.
    """
    parser = argparse.ArgumentParser(description="Run Training on Hippocampus Segmentation")
    parser.add_argument(
        "-e",
        "--epochs",
        default=config.epochs, type=int,
        help="Specify the number of epochs required for training"
    )
    parser.add_argument(
        "-b",
        "--batch",
        default=config.batch_size, type=int,
        help="Specify the batch size"
    )
    parser.add_argument(
        "-v",
        "--val_epochs",
        default=config.val_epochs, type=int,
        help="Specify the number of validation epochs during training ** FOR FUTURE RELEASES **"
    )
    parser.add_argument(
        "-w",
        "--workers",
        default=config.num_workers, type=int,
        help="Specify the number of workers"
    )
    parser.add_argument(
        "--net",
        default='vnet',
        # Reject unsupported names at parse time instead of failing later,
        # after the dataset checks have already run.
        choices=['unet', 'vnet'],
        help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **"
    )
    parser.add_argument(
        "--lr",
        default=config.lr, type=float,
        help="Learning Rate"
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    config = SemSegMRIConfig()

    # Command line values override the defaults stored in the configuration.
    args = parse_cli_args(config)
    config.net = args.net
    config.epochs = args.epochs
    config.batch_size = args.batch
    config.val_epochs = args.val_epochs
    config.num_workers = args.workers
    config.lr = args.lr

    run(config)