--- a +++ b/run/train.py @@ -0,0 +1,172 @@ +########################## +# Nicola Altini (2020) +# V-Net for Hippocampus Segmentation from MRI with PyTorch +########################## +# python run/train.py +# python run/train.py --epochs=NUM_EPOCHS --batch=BATCH_SIZE --workers=NUM_WORKERS --lr=LR +# python run/train.py --epochs=5 --batch=1 --net=unet + +########################## +# Imports +########################## +import argparse +import os +import sys +import numpy as np +import torch +import torch.optim as optim +from sklearn.model_selection import KFold + +########################## +# Local Imports +########################## +current_path_abs = os.path.abspath('.') +sys.path.append(current_path_abs) +print('{} appended to sys!'.format(current_path_abs)) + +from run.utils import print_config, check_train_set, check_torch_loader, print_folder, train_val_split_config +from config.config import SemSegMRIConfig +from config.paths import logs_folder +from semseg.train import train_model, val_model +from semseg.data_loader import TorchIODataLoader3DTraining, TorchIODataLoader3DValidation +from models.vnet3d import VNet3D +from models.unet3d import UNet3D + + +def get_net(config): + name = config.net + assert name in ['unet', 'vnet'], "Network Name not valid or not supported! Use one of ['unet', 'vnet']" + if name == 'vnet': + return VNet3D(num_outs=config.num_outs, channels=config.num_channels) + elif name == 'unet': + return UNet3D(num_out_classes=config.num_outs, input_channels=1, init_feat_channels=32) + + +def run(config): + ########################## + # Check training set + ########################## + check_train_set(config) + + ########################## + # Config + ########################## + print_config(config) + + ########################## + # Check Torch DataLoader and Net + ########################## + check_torch_loader(config, check_net=False) + + ########################## + # Training loop + ########################## + cuda_dev = torch.device('cuda') + + if config.do_crossval: + ########################## + # Training (cross-validation) + ########################## + multi_dices_crossval = list() + mean_multi_dice_crossval = list() + std_multi_dice_crossval = list() + + kf = KFold(n_splits=config.num_folders) + for idx, (train_index, val_index) in enumerate(kf.split(config.train_images)): + print_folder(idx, train_index, val_index) + config_crossval = train_val_split_config(config, train_index, val_index) + + ########################## + # Training (cross-validation) + ########################## + net = get_net(config_crossval) + config_crossval.lr = 0.01 + optimizer = optim.Adam(net.parameters(), lr=config_crossval.lr) + train_data_loader_3D = TorchIODataLoader3DTraining(config_crossval) + net = train_model(net, optimizer, train_data_loader_3D, + config_crossval, device=cuda_dev, logs_folder=logs_folder) + + ########################## + # Validation (cross-validation) + ########################## + val_data_loader_3D = TorchIODataLoader3DValidation(config_crossval) + multi_dices, mean_multi_dice, std_multi_dice = val_model(net, val_data_loader_3D, + config_crossval, device=cuda_dev) + multi_dices_crossval.append(multi_dices) + mean_multi_dice_crossval.append(mean_multi_dice) + std_multi_dice_crossval.append(std_multi_dice) + torch.save(net, os.path.join(logs_folder, "model_folder_{:d}.pt".format(idx))) + + ########################## + # Saving Validation Results + ########################## + multi_dices_crossval_flatten = [item for sublist in multi_dices_crossval for item in sublist] + mean_multi_dice_crossval_flatten = np.mean(multi_dices_crossval_flatten) + std_multi_dice_crossval_flatten = np.std(multi_dices_crossval_flatten) + print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice_crossval_flatten, std_multi_dice_crossval_flatten)) + # Multi-Dice: 0.8728 +/- 0.0227 + + ########################## + # Training (full training set) + ########################## + net = get_net(config) + config.lr = 0.01 + optimizer = optim.Adam(net.parameters(), lr=config.lr) + train_data_loader_3D = TorchIODataLoader3DTraining(config) + net = train_model(net, optimizer, train_data_loader_3D, + config, device=cuda_dev, logs_folder=logs_folder) + + torch.save(net,os.path.join(logs_folder,"model.pt")) + + +############################ +# MAIN +############################ +if __name__ == "__main__": + config = SemSegMRIConfig() + + parser = argparse.ArgumentParser(description="Run Training on Hippocampus Segmentation") + parser.add_argument( + "-e", + "--epochs", + default=config.epochs, type=int, + help="Specify the number of epochs required for training" + ) + parser.add_argument( + "-b", + "--batch", + default=config.batch_size, type=int, + help="Specify the batch size" + ) + parser.add_argument( + "-v", + "--val_epochs", + default=config.val_epochs, type=int, + help="Specify the number of validation epochs during training ** FOR FUTURE RELEASES **" + ) + parser.add_argument( + "-w", + "--workers", + default=config.num_workers, type=int, + help="Specify the number of workers" + ) + parser.add_argument( + "--net", + default='vnet', + help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **" + ) + parser.add_argument( + "--lr", + default=config.lr, type=float, + help="Learning Rate" + ) + + args = parser.parse_args() + config.net = args.net + config.epochs = args.epochs + config.batch_size = args.batch + config.val_epochs = args.val_epochs + config.num_workers = args.workers + config.lr = args.lr + + run(config)