Diff of /run/validate.py [000000] .. [cc8b8f]

Switch to side-by-side view

--- a
+++ b/run/validate.py
@@ -0,0 +1,185 @@
+##########################
+# Nicola Altini (2020)
+# V-Net for Hippocampus Segmentation from MRI with PyTorch
+##########################
+# python run/validate.py
+# python run/validate.py --dir=logs/no_augm_torchio
+# python run/validate.py --dir=logs/no_augm_torchio --write=0
+# python run/validate.py --dir=path/to/logs/dir --write=WRITE --verbose=VERBOSE
+
+##########################
+# Imports
+##########################
+import torch
+import numpy as np
+import os
+from sklearn.model_selection import KFold
+import argparse
+import sys
+
+##########################
+# Local Imports
+##########################
+current_path_abs = os.path.abspath('.')
+sys.path.append(current_path_abs)
+print('{} appended to sys!'.format(current_path_abs))
+
+from config.paths import ( train_images_folder, train_labels_folder, train_prediction_folder,
+                           train_images, train_labels,
+                           test_images_folder, test_images, test_prediction_folder)
+from run.utils import (train_val_split, print_folder, nii_load, sitk_load, nii_write, print_config,
+                       sitk_write, print_test, np3d_to_torch5d, torch5d_to_np3d, print_metrics, plot_confusion_matrix)
+from config.config import SemSegMRIConfig
+from semseg.utils import multi_dice_coeff
+from sklearn.metrics import confusion_matrix, f1_score
+
+
+def run(logs_dir="logs", write_out=False, plot_conf=False):
+    ##########################
+    # Config
+    ##########################
+    config = SemSegMRIConfig()
+    print_config(config)
+
+    ###########################
+    # Load Net
+    ###########################
+    cuda_dev = torch.device("cuda")
+
+    # Load From State Dict
+    # path_net = "logs/model_epoch_0080.pht"
+    # net = VNet3D(num_outs=config.num_outs, channels=config.num_channels)
+    # net.load_state_dict(torch.load(path_net))
+
+    path_net = os.path.join(logs_dir,"model.pt")
+    path_nets_crossval = [os.path.join(logs_dir,"model_folder_{:d}.pt".format(idx))
+                          for idx in range(config.num_folders)]
+
+    ###########################
+    # Eval Loop
+    ###########################
+    use_nib = True
+    pad_ref = (48,64,48)
+    multi_dices = list()
+    f1_scores = list()
+
+    os.makedirs(train_prediction_folder, exist_ok=True)
+    os.makedirs(test_prediction_folder, exist_ok=True)
+
+    train_and_test = [True, False]
+    train_and_test_images = [train_images, test_images]
+    train_and_test_images_folder = [train_images_folder, test_images_folder]
+    train_and_test_prediction_folder = [train_prediction_folder, test_prediction_folder]
+    os.makedirs(train_prediction_folder,exist_ok=True)
+    os.makedirs(test_prediction_folder,exist_ok=True)
+
+    train_confusion_matrix = np.zeros((config.num_outs, config.num_outs))
+
+    for train_or_test_images, train_or_test_images_folder, train_or_test_prediction_folder, is_training in \
+            zip(train_and_test_images, train_and_test_images_folder, train_and_test_prediction_folder, train_and_test):
+        print("Images Folder: {}".format(train_or_test_images_folder))
+        print("IsTraining: {}".format(is_training))
+
+        kf = KFold(n_splits=config.num_folders)
+        for idx_crossval, (train_index, val_index) in enumerate(kf.split(train_images)):
+            if is_training:
+                print_folder(idx_crossval, train_index, val_index)
+                model_path = path_nets_crossval[idx_crossval]
+                print("Model: {}".format(model_path))
+                net = torch.load(model_path)
+                _, train_or_test_images, _, train_labels_crossval = \
+                    train_val_split(train_images, train_labels, train_index, val_index)
+            else:
+                print_test()
+                net = torch.load(path_net)
+            net = net.cuda(cuda_dev)
+            net.eval()
+
+            for idx, train_image in enumerate(train_or_test_images):
+                print("Iter {} on {}".format(idx,len(train_or_test_images)))
+                print("Image: {}".format(train_image))
+                train_image_path = os.path.join(train_or_test_images_folder, train_image)
+
+                if use_nib:
+                    train_image_np, affine = nii_load(train_image_path)
+                else:
+                    train_image_np, meta_sitk = sitk_load(train_image_path)
+
+                with torch.no_grad():
+                    inputs = np3d_to_torch5d(train_image_np, pad_ref, cuda_dev)
+                    outputs = net(inputs)
+                    outputs_np = torch5d_to_np3d(outputs, train_image_np.shape)
+
+                if write_out:
+                    filename_out = os.path.join(train_or_test_prediction_folder, train_image)
+                    if use_nib:
+                        nii_write(outputs_np, affine, filename_out)
+                    else:
+                        sitk_write(outputs_np, meta_sitk, filename_out)
+
+                if is_training:
+                    train_label = train_labels_crossval[idx]
+                    train_label_path = os.path.join(train_labels_folder, train_label)
+                    if use_nib:
+                        train_label_np, _ = nii_load(train_label_path)
+                    else:
+                        train_label_np, _ = sitk_load(train_label_path)
+
+                    multi_dice = multi_dice_coeff(np.expand_dims(train_label_np,axis=0),
+                                                  np.expand_dims(outputs_np,axis=0),
+                                                  config.num_outs)
+                    print("Multi Class Dice Coeff = {:.4f}".format(multi_dice))
+                    multi_dices.append(multi_dice)
+
+                    f1_score_idx = f1_score(train_label_np.flatten(), outputs_np.flatten(), average=None)
+                    cm_idx = confusion_matrix(train_label_np.flatten(), outputs_np.flatten())
+                    train_confusion_matrix += cm_idx
+                    f1_scores.append(f1_score_idx)
+
+            if not is_training:
+                break
+
+    print_metrics(multi_dices, f1_scores, train_confusion_matrix)
+
+    if plot_conf:
+        plot_confusion_matrix(train_confusion_matrix,
+                              target_names=None, title='Cross-Validation Confusion matrix',
+                              cmap=None, normalize=False, already_normalized=False,
+                              path_out="images/conf_matrix_no_norm_no_augm_torchio.png")
+        plot_confusion_matrix(train_confusion_matrix,
+                              target_names=None, title='Cross-Validation Confusion matrix (row-normalized)',
+                              cmap=None, normalize=True, already_normalized=False,
+                              path_out="images/conf_matrix_normalized_row_no_augm_torchio.png")
+
+
+############################
+# MAIN
+############################
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run Validation for Hippocampus Segmentation")
+    parser.add_argument(
+        "-V",
+        "--verbose",
+        default=False, type=bool,
+        help="Boolean flag. Set to true for VERBOSE mode; false otherwise."
+    )
+    parser.add_argument(
+        "-D",
+        "--dir",
+        default="logs", type=str,
+        help="Local path to logs dir"
+    )
+    parser.add_argument(
+        "-W",
+        "--write",
+        default=False, type=bool,
+        help="Boolean flag. Set to true for WRITE mode; false otherwise."
+    )
+    parser.add_argument(
+        "--net",
+        default='vnet',
+        help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **"
+    )
+
+    args = parser.parse_args()
+    run(logs_dir=args.dir, write_out=args.write, plot_conf=args.verbose)