--- a
+++ b/run/train.py
@@ -0,0 +1,173 @@
+##########################
+# Nicola Altini (2020)
+# V-Net for Hippocampus Segmentation from MRI with PyTorch
+##########################
+# python run/train.py
+# python run/train.py --epochs=NUM_EPOCHS --batch=BATCH_SIZE --workers=NUM_WORKERS --lr=LR
+# python run/train.py --epochs=5 --batch=1 --net=unet
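+# NOTE: run this script from the repository root so the local packages imported below (run/, config/, semseg/, models/) resolve via sys.path.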
+
+##########################
+# Imports
+##########################
+import argparse
+import os
+import sys
+import numpy as np
+import torch
+import torch.optim as optim
+from sklearn.model_selection import KFold
+
+##########################
+# Local Imports
+##########################
+current_path_abs = os.path.abspath('.')
+sys.path.append(current_path_abs)
+print('{} appended to sys.path!'.format(current_path_abs))
+
+from run.utils import print_config, check_train_set, check_torch_loader, print_folder, train_val_split_config
+from config.config import SemSegMRIConfig
+from config.paths import logs_folder
+from semseg.train import train_model, val_model
+from semseg.data_loader import TorchIODataLoader3DTraining, TorchIODataLoader3DValidation
+from models.vnet3d import VNet3D
+from models.unet3d import UNet3D
+
+
+def get_net(config):
+    name = config.net
+    assert name in ['unet', 'vnet'], "Network Name not valid or not supported! Use one of ['unet', 'vnet']"
+    if name == 'vnet':
+        return VNet3D(num_outs=config.num_outs, channels=config.num_channels)
+    elif name == 'unet':
+        return UNet3D(num_out_classes=config.num_outs, input_channels=1, init_feat_channels=32)
+
+
+def run(config):
+    ##########################
+    # Check training set
+    ##########################
+    check_train_set(config)
+
+    ##########################
+    # Config
+    ##########################
+    print_config(config)
+
+    ##########################
+    # Check Torch DataLoader and Net
+    ##########################
+    check_torch_loader(config, check_net=False)
+
+    ##########################
+    # Training loop
+    ##########################
+    cuda_dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    if config.do_crossval:
+        ##########################
+        # Cross-validation
+        ##########################
+        multi_dices_crossval = list()
+        mean_multi_dice_crossval = list()
+        std_multi_dice_crossval = list()
+
+        kf = KFold(n_splits=config.num_folders)
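+        # Each KFold split yields (train_index, val_index) arrays over config.train_images.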
+        for idx, (train_index, val_index) in enumerate(kf.split(config.train_images)):
+            print_folder(idx, train_index, val_index)
+            config_crossval = train_val_split_config(config, train_index, val_index)
+
+            ##########################
+            # Training (current fold)
+            ##########################
+            net = get_net(config_crossval)
+            optimizer = optim.Adam(net.parameters(), lr=config_crossval.lr)
+            train_data_loader_3D = TorchIODataLoader3DTraining(config_crossval)
+            net = train_model(net, optimizer, train_data_loader_3D,
+                              config_crossval, device=cuda_dev, logs_folder=logs_folder)
+
+            ##########################
+            # Validation (current fold)
+            ##########################
+            val_data_loader_3D = TorchIODataLoader3DValidation(config_crossval)
+            multi_dices, mean_multi_dice, std_multi_dice = val_model(net, val_data_loader_3D,
+                                                                     config_crossval, device=cuda_dev)
+            multi_dices_crossval.append(multi_dices)
+            mean_multi_dice_crossval.append(mean_multi_dice)
+            std_multi_dice_crossval.append(std_multi_dice)
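+            # torch.save on the model object pickles the whole network; loading it later requires the VNet3D/UNet3D class to be importable.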
+            torch.save(net, os.path.join(logs_folder, "model_folder_{:d}.pt".format(idx)))
+
+        ##########################
+        # Saving Validation Results
+        ##########################
+        multi_dices_crossval_flatten = [item for sublist in multi_dices_crossval for item in sublist]
+        mean_multi_dice_crossval_flatten = np.mean(multi_dices_crossval_flatten)
+        std_multi_dice_crossval_flatten = np.std(multi_dices_crossval_flatten)
+        print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice_crossval_flatten, std_multi_dice_crossval_flatten))
+        # Multi-Dice: 0.8728 +/- 0.0227
+
+    ##########################
+    # Training (full training set)
+    ##########################
+    net = get_net(config)
+    optimizer = optim.Adam(net.parameters(), lr=config.lr)
+    train_data_loader_3D = TorchIODataLoader3DTraining(config)
+    net = train_model(net, optimizer, train_data_loader_3D,
+                      config, device=cuda_dev, logs_folder=logs_folder)
+
+    torch.save(net, os.path.join(logs_folder, "model.pt"))
+
+
+############################
+# MAIN
+############################
+if __name__ == "__main__":
+    config = SemSegMRIConfig()
+
+    parser = argparse.ArgumentParser(description="Run Training on Hippocampus Segmentation")
+    parser.add_argument(
+        "-e",
+        "--epochs",
+        default=config.epochs, type=int,
+        help="Specify the number of epochs required for training"
+    )
+    parser.add_argument(
+        "-b",
+        "--batch",
+        default=config.batch_size, type=int,
+        help="Specify the batch size"
+    )
+    parser.add_argument(
+        "-v",
+        "--val_epochs",
+        default=config.val_epochs, type=int,
+        help="Specify the number of validation epochs during training ** FOR FUTURE RELEASES **"
+    )
+    parser.add_argument(
+        "-w",
+        "--workers",
+        default=config.num_workers, type=int,
+        help="Specify the number of workers"
+    )
+    parser.add_argument(
+        "--net",
+        default='vnet',
+        help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **"
+    )
+    parser.add_argument(
+        "--lr",
+        default=config.lr, type=float,
+        help="Learning Rate"
+    )
+
+    args = parser.parse_args()
+    config.net = args.net
+    config.epochs = args.epochs
+    config.batch_size = args.batch
+    config.val_epochs = args.val_epochs
+    config.num_workers = args.workers
+    config.lr = args.lr
+
+    run(config)