--- a +++ b/run/validate.py @@ -0,0 +1,185 @@ +########################## +# Nicola Altini (2020) +# V-Net for Hippocampus Segmentation from MRI with PyTorch +########################## +# python run/validate.py +# python run/validate.py --dir=logs/no_augm_torchio +# python run/validate.py --dir=logs/no_augm_torchio --write=0 +# python run/validate.py --dir=path/to/logs/dir --write=WRITE --verbose=VERBOSE + +########################## +# Imports +########################## +import torch +import numpy as np +import os +from sklearn.model_selection import KFold +import argparse +import sys + +########################## +# Local Imports +########################## +current_path_abs = os.path.abspath('.') +sys.path.append(current_path_abs) +print('{} appended to sys!'.format(current_path_abs)) + +from config.paths import ( train_images_folder, train_labels_folder, train_prediction_folder, + train_images, train_labels, + test_images_folder, test_images, test_prediction_folder) +from run.utils import (train_val_split, print_folder, nii_load, sitk_load, nii_write, print_config, + sitk_write, print_test, np3d_to_torch5d, torch5d_to_np3d, print_metrics, plot_confusion_matrix) +from config.config import SemSegMRIConfig +from semseg.utils import multi_dice_coeff +from sklearn.metrics import confusion_matrix, f1_score + + +def run(logs_dir="logs", write_out=False, plot_conf=False): + ########################## + # Config + ########################## + config = SemSegMRIConfig() + print_config(config) + + ########################### + # Load Net + ########################### + cuda_dev = torch.device("cuda") + + # Load From State Dict + # path_net = "logs/model_epoch_0080.pht" + # net = VNet3D(num_outs=config.num_outs, channels=config.num_channels) + # net.load_state_dict(torch.load(path_net)) + + path_net = os.path.join(logs_dir,"model.pt") + path_nets_crossval = [os.path.join(logs_dir,"model_folder_{:d}.pt".format(idx)) + for idx in range(config.num_folders)] + + ########################### + # Eval Loop + ########################### + use_nib = True + pad_ref = (48,64,48) + multi_dices = list() + f1_scores = list() + + os.makedirs(train_prediction_folder, exist_ok=True) + os.makedirs(test_prediction_folder, exist_ok=True) + + train_and_test = [True, False] + train_and_test_images = [train_images, test_images] + train_and_test_images_folder = [train_images_folder, test_images_folder] + train_and_test_prediction_folder = [train_prediction_folder, test_prediction_folder] + os.makedirs(train_prediction_folder,exist_ok=True) + os.makedirs(test_prediction_folder,exist_ok=True) + + train_confusion_matrix = np.zeros((config.num_outs, config.num_outs)) + + for train_or_test_images, train_or_test_images_folder, train_or_test_prediction_folder, is_training in \ + zip(train_and_test_images, train_and_test_images_folder, train_and_test_prediction_folder, train_and_test): + print("Images Folder: {}".format(train_or_test_images_folder)) + print("IsTraining: {}".format(is_training)) + + kf = KFold(n_splits=config.num_folders) + for idx_crossval, (train_index, val_index) in enumerate(kf.split(train_images)): + if is_training: + print_folder(idx_crossval, train_index, val_index) + model_path = path_nets_crossval[idx_crossval] + print("Model: {}".format(model_path)) + net = torch.load(model_path) + _, train_or_test_images, _, train_labels_crossval = \ + train_val_split(train_images, train_labels, train_index, val_index) + else: + print_test() + net = torch.load(path_net) + net = net.cuda(cuda_dev) + net.eval() + + for idx, train_image in enumerate(train_or_test_images): + print("Iter {} on {}".format(idx,len(train_or_test_images))) + print("Image: {}".format(train_image)) + train_image_path = os.path.join(train_or_test_images_folder, train_image) + + if use_nib: + train_image_np, affine = nii_load(train_image_path) + else: + train_image_np, meta_sitk = sitk_load(train_image_path) + + with torch.no_grad(): + inputs = np3d_to_torch5d(train_image_np, pad_ref, cuda_dev) + outputs = net(inputs) + outputs_np = torch5d_to_np3d(outputs, train_image_np.shape) + + if write_out: + filename_out = os.path.join(train_or_test_prediction_folder, train_image) + if use_nib: + nii_write(outputs_np, affine, filename_out) + else: + sitk_write(outputs_np, meta_sitk, filename_out) + + if is_training: + train_label = train_labels_crossval[idx] + train_label_path = os.path.join(train_labels_folder, train_label) + if use_nib: + train_label_np, _ = nii_load(train_label_path) + else: + train_label_np, _ = sitk_load(train_label_path) + + multi_dice = multi_dice_coeff(np.expand_dims(train_label_np,axis=0), + np.expand_dims(outputs_np,axis=0), + config.num_outs) + print("Multi Class Dice Coeff = {:.4f}".format(multi_dice)) + multi_dices.append(multi_dice) + + f1_score_idx = f1_score(train_label_np.flatten(), outputs_np.flatten(), average=None) + cm_idx = confusion_matrix(train_label_np.flatten(), outputs_np.flatten()) + train_confusion_matrix += cm_idx + f1_scores.append(f1_score_idx) + + if not is_training: + break + + print_metrics(multi_dices, f1_scores, train_confusion_matrix) + + if plot_conf: + plot_confusion_matrix(train_confusion_matrix, + target_names=None, title='Cross-Validation Confusion matrix', + cmap=None, normalize=False, already_normalized=False, + path_out="images/conf_matrix_no_norm_no_augm_torchio.png") + plot_confusion_matrix(train_confusion_matrix, + target_names=None, title='Cross-Validation Confusion matrix (row-normalized)', + cmap=None, normalize=True, already_normalized=False, + path_out="images/conf_matrix_normalized_row_no_augm_torchio.png") + + +############################ +# MAIN +############################ +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run Validation for Hippocampus Segmentation") + parser.add_argument( + "-V", + "--verbose", + default=False, type=bool, + help="Boolean flag. Set to true for VERBOSE mode; false otherwise." + ) + parser.add_argument( + "-D", + "--dir", + default="logs", type=str, + help="Local path to logs dir" + ) + parser.add_argument( + "-W", + "--write", + default=False, type=bool, + help="Boolean flag. Set to true for WRITE mode; false otherwise." + ) + parser.add_argument( + "--net", + default='vnet', + help="Specify the network to use [unet | vnet] ** FOR FUTURE RELEASES **" + ) + + args = parser.parse_args() + run(logs_dir=args.dir, write_out=args.write, plot_conf=args.verbose)