|
a |
|
b/run/train.py |
|
|
1 |
########################## |
|
|
2 |
# Nicola Altini (2020) |
|
|
3 |
# V-Net for Hippocampus Segmentation from MRI with PyTorch |
|
|
4 |
########################## |
|
|
5 |
# python run/train.py |
|
|
6 |
# python run/train.py --epochs=NUM_EPOCHS --batch=BATCH_SIZE --workers=NUM_WORKERS --lr=LR |
|
|
7 |
# python run/train.py --epochs=5 --batch=1 --net=unet |
|
|
8 |
|
|
|
9 |
########################## |
|
|
10 |
# Imports |
|
|
11 |
########################## |
|
|
12 |
import argparse |
|
|
13 |
import os |
|
|
14 |
import sys |
|
|
15 |
import numpy as np |
|
|
16 |
import torch |
|
|
17 |
import torch.optim as optim |
|
|
18 |
from sklearn.model_selection import KFold |
|
|
19 |
|
|
|
20 |
########################## |
|
|
21 |
# Local Imports |
|
|
22 |
########################## |
|
|
23 |
current_path_abs = os.path.abspath('.') |
|
|
24 |
sys.path.append(current_path_abs) |
|
|
25 |
print('{} appended to sys!'.format(current_path_abs)) |
|
|
26 |
|
|
|
27 |
from run.utils import print_config, check_train_set, check_torch_loader, print_folder, train_val_split_config |
|
|
28 |
from config.config import SemSegMRIConfig |
|
|
29 |
from config.paths import logs_folder |
|
|
30 |
from semseg.train import train_model, val_model |
|
|
31 |
from semseg.data_loader import TorchIODataLoader3DTraining, TorchIODataLoader3DValidation |
|
|
32 |
from models.vnet3d import VNet3D |
|
|
33 |
from models.unet3d import UNet3D |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def get_net(config): |
|
|
37 |
name = config.net |
|
|
38 |
assert name in ['unet', 'vnet'], "Network Name not valid or not supported! Use one of ['unet', 'vnet']" |
|
|
39 |
if name == 'vnet': |
|
|
40 |
return VNet3D(num_outs=config.num_outs, channels=config.num_channels) |
|
|
41 |
elif name == 'unet': |
|
|
42 |
return UNet3D(num_out_classes=config.num_outs, input_channels=1, init_feat_channels=32) |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def run(config): |
|
|
46 |
########################## |
|
|
47 |
# Check training set |
|
|
48 |
########################## |
|
|
49 |
check_train_set(config) |
|
|
50 |
|
|
|
51 |
########################## |
|
|
52 |
# Config |
|
|
53 |
########################## |
|
|
54 |
print_config(config) |
|
|
55 |
|
|
|
56 |
########################## |
|
|
57 |
# Check Torch DataLoader and Net |
|
|
58 |
########################## |
|
|
59 |
check_torch_loader(config, check_net=False) |
|
|
60 |
|
|
|
61 |
########################## |
|
|
62 |
# Training loop |
|
|
63 |
########################## |
|
|
64 |
cuda_dev = torch.device('cuda') |
|
|
65 |
|
|
|
66 |
if config.do_crossval: |
|
|
67 |
########################## |
|
|
68 |
# Training (cross-validation) |
|
|
69 |
########################## |
|
|
70 |
multi_dices_crossval = list() |
|
|
71 |
mean_multi_dice_crossval = list() |
|
|
72 |
std_multi_dice_crossval = list() |
|
|
73 |
|
|
|
74 |
kf = KFold(n_splits=config.num_folders) |
|
|
75 |
for idx, (train_index, val_index) in enumerate(kf.split(config.train_images)): |
|
|
76 |
print_folder(idx, train_index, val_index) |
|
|
77 |
config_crossval = train_val_split_config(config, train_index, val_index) |
|
|
78 |
|
|
|
79 |
########################## |
|
|
80 |
# Training (cross-validation) |
|
|
81 |
########################## |
|
|
82 |
net = get_net(config_crossval) |
|
|
83 |
config_crossval.lr = 0.01 |
|
|
84 |
optimizer = optim.Adam(net.parameters(), lr=config_crossval.lr) |
|
|
85 |
train_data_loader_3D = TorchIODataLoader3DTraining(config_crossval) |
|
|
86 |
net = train_model(net, optimizer, train_data_loader_3D, |
|
|
87 |
config_crossval, device=cuda_dev, logs_folder=logs_folder) |
|
|
88 |
|
|
|
89 |
########################## |
|
|
90 |
# Validation (cross-validation) |
|
|
91 |
########################## |
|
|
92 |
val_data_loader_3D = TorchIODataLoader3DValidation(config_crossval) |
|
|
93 |
multi_dices, mean_multi_dice, std_multi_dice = val_model(net, val_data_loader_3D, |
|
|
94 |
config_crossval, device=cuda_dev) |
|
|
95 |
multi_dices_crossval.append(multi_dices) |
|
|
96 |
mean_multi_dice_crossval.append(mean_multi_dice) |
|
|
97 |
std_multi_dice_crossval.append(std_multi_dice) |
|
|
98 |
torch.save(net, os.path.join(logs_folder, "model_folder_{:d}.pt".format(idx))) |
|
|
99 |
|
|
|
100 |
########################## |
|
|
101 |
# Saving Validation Results |
|
|
102 |
########################## |
|
|
103 |
multi_dices_crossval_flatten = [item for sublist in multi_dices_crossval for item in sublist] |
|
|
104 |
mean_multi_dice_crossval_flatten = np.mean(multi_dices_crossval_flatten) |
|
|
105 |
std_multi_dice_crossval_flatten = np.std(multi_dices_crossval_flatten) |
|
|
106 |
print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice_crossval_flatten, std_multi_dice_crossval_flatten)) |
|
|
107 |
# Multi-Dice: 0.8728 +/- 0.0227 |
|
|
108 |
|
|
|
109 |
########################## |
|
|
110 |
# Training (full training set) |
|
|
111 |
########################## |
|
|
112 |
net = get_net(config) |
|
|
113 |
config.lr = 0.01 |
|
|
114 |
optimizer = optim.Adam(net.parameters(), lr=config.lr) |
|
|
115 |
train_data_loader_3D = TorchIODataLoader3DTraining(config) |
|
|
116 |
net = train_model(net, optimizer, train_data_loader_3D, |
|
|
117 |
config, device=cuda_dev, logs_folder=logs_folder) |
|
|
118 |
|
|
|
119 |
torch.save(net,os.path.join(logs_folder,"model.pt")) |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
############################ |
|
|
123 |
# MAIN |
|
|
124 |
############################ |
|
|
125 |
if __name__ == "__main__": |
|
|
126 |
config = SemSegMRIConfig() |
|
|
127 |
|
|
|
128 |
parser = argparse.ArgumentParser(description="Run Training on Hippocampus Segmentation") |
|
|
129 |
parser.add_argument( |
|
|
130 |
"-e", |
|
|
131 |
"--epochs", |
|
|
132 |
default=config.epochs, type=int, |
|
|
133 |
help="Specify the number of epochs required for training" |
|
|
134 |
) |
|
|
135 |
parser.add_argument( |
|
|
136 |
"-b", |
|
|
137 |
"--batch", |
|
|
138 |
default=config.batch_size, type=int, |
|
|
139 |
help="Specify the batch size" |
|
|
140 |
) |
|
|
141 |
parser.add_argument( |
|
|
142 |
"-v", |
|
|
143 |
"--val_epochs", |
|
|
144 |
default=config.val_epochs, type=int, |
|
|
145 |
help="Specify the number of validation epochs during training ** FOR FUTURE RELEASES **" |
|
|
146 |
) |
|
|
147 |
parser.add_argument( |
|
|
148 |
"-w", |
|
|
149 |
"--workers", |
|
|
150 |
default=config.num_workers, type=int, |
|
|
151 |
help="Specify the number of workers" |
|
|
152 |
) |
|
|
153 |
parser.add_argument( |
|
|
154 |
"--net", |
|
|
155 |
default='vnet', |
|
|
156 |
help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **" |
|
|
157 |
) |
|
|
158 |
parser.add_argument( |
|
|
159 |
"--lr", |
|
|
160 |
default=config.lr, type=float, |
|
|
161 |
help="Learning Rate" |
|
|
162 |
) |
|
|
163 |
|
|
|
164 |
args = parser.parse_args() |
|
|
165 |
config.net = args.net |
|
|
166 |
config.epochs = args.epochs |
|
|
167 |
config.batch_size = args.batch |
|
|
168 |
config.val_epochs = args.val_epochs |
|
|
169 |
config.num_workers = args.workers |
|
|
170 |
config.lr = args.lr |
|
|
171 |
|
|
|
172 |
run(config) |