--- /dev/null
+++ b/src/uncertainty/uncertainty.py
@@ -0,0 +1,109 @@
+from typing import List, Tuple
+import os
+
+import numpy as np
+import torch
+from scipy import stats
+from tqdm import tqdm
+
+from src.dataset.utils import nifi_volume
+
+
+def get_entropy_uncertainty(prediction_score_vectors: List[torch.Tensor], matrix_size: Tuple) -> np.ndarray:
+    """
+    Compute a global uncertainty map as the entropy of the mean class scores over the TTD passes.
+
+    :param prediction_score_vectors: list of tensors with the predicted scores per label, one tensor per TTD pass
+    :param matrix_size: spatial shape of the output volume
+    :return: a uint8 uncertainty map scaled to the range [0, 100]
+    """
+    prediction_score_vectors = torch.stack(tuple(prediction_score_vectors))
+    mean = np.mean(prediction_score_vectors.cpu().numpy(), axis=0)
+    entropy = stats.entropy(mean, axis=1).reshape(matrix_size) * 100
+    return entropy.astype(np.uint8)
+
+
+def get_variation_uncertainty(prediction_score_vectors: List[torch.Tensor],
+                              matrix_size: Tuple) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Compute uncertainty as the variance of the predictions for the evaluation regions WT, TC and ET.
+
+    :param prediction_score_vectors: list of tensors with the predicted scores per label, one tensor per TTD pass
+    :param matrix_size: spatial shape of the output volumes
+    :return: uint8 uncertainty maps for WT, TC and ET, scaled to the range [0, 100]
+    """
+    prediction_score_vectors = torch.stack(tuple(prediction_score_vectors))
+
+    wt_var = np.var(np.sum(prediction_score_vectors[:, :, 1:].cpu().numpy(), axis=2), axis=0).reshape(matrix_size) * 100
+    tc_var = np.var(np.sum(prediction_score_vectors[:, :, [1, 3]].cpu().numpy(), axis=2), axis=0).reshape(matrix_size) * 100
+    et_var = np.var(prediction_score_vectors[:, :, 3].cpu().numpy(), axis=0).reshape(matrix_size) * 100
+
+    return wt_var.astype(np.uint8), tc_var.astype(np.uint8), et_var.astype(np.uint8)
+
+
+def brats_normalize(uncertainty_map: np.ndarray, max_unc: int, min_unc: int) -> np.ndarray:
+    """Rescale the non-zero uncertainty values from [min_unc, max_unc] to the BraTS range [0, 100]."""
+    minimum = 0
+    maximum = 100
+    step = (maximum - minimum) / (max_unc - min_unc)
+    normalized = np.where(uncertainty_map != 0, (uncertainty_map - min_unc) * step, 0)
+    return normalized.astype(np.uint8)
+
+
+def compute_normalization(input_dir, output_dir, ground_truth_path):
+
+    file_list = sorted([file for file in os.listdir(input_dir) if "unc" in file and "nii.gz" in file])
+    file_list_all = sorted([file for file in os.listdir(input_dir) if "nii.gz" in file])
+
+    max_uncertainty = 0
+    min_uncertainty = np.inf
+
+    for uncertainty_map in tqdm(file_list, total=len(file_list), desc="Getting min and max"):
+
+        # Build a brain mask from FLAIR so background voxels do not bias the statistics
+        patient_name = uncertainty_map.split(".")[0].split("_unc")[0]
+        path_gt = os.path.join(ground_truth_path, patient_name, f"{patient_name}_flair.nii.gz")
+        flair = nifi_volume.load_nifi_volume(path_gt, normalize=False)
+        brain_mask = np.zeros(flair.shape, dtype=float)
+        brain_mask[flair > 0] = 1
+
+        path = os.path.join(input_dir, uncertainty_map)
+        unc_map, _ = nifi_volume.load_nifi_volume_return_nib(path, normalize=False)
+
+        max_uncertainty = max(max_uncertainty, np.max(unc_map[brain_mask == 1]))
+        min_uncertainty = min(min_uncertainty, np.min(unc_map[brain_mask == 1]))
+
+    for uncertainty_map_path in tqdm(file_list_all, total=len(file_list_all), desc="Normalizing..."):
+
+        path = os.path.join(input_dir, uncertainty_map_path)
+        output_path = os.path.join(output_dir, uncertainty_map_path)
+
+        unc_map, nib_data = nifi_volume.load_nifi_volume_return_nib(path, normalize=False)
+
+        if "unc" in uncertainty_map_path:
+            uncertainty_map_normalized = brats_normalize(
+                unc_map, max_unc=max_uncertainty, min_unc=min_uncertainty)
+            print(f"Saving to: {output_path}")
+            nifi_volume.save_segmask_as_nifi_volume(uncertainty_map_normalized, nib_data.affine, output_path)
+        else:
+            nifi_volume.save_segmask_as_nifi_volume(unc_map, nib_data.affine, output_path)
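+
+
+if __name__ == "__main__":
+    # Minimal smoke test, assuming nothing beyond this module: random softmax
+    # vectors stand in for the per-voxel class scores that a segmentation
+    # model would produce across several TTD passes. Real BraTS volumes are
+    # 240x240x155 voxels; a toy shape keeps the check cheap.
+    toy_size = (24, 24, 16)
+    n_voxels = int(np.prod(toy_size))
+    scores = [torch.softmax(torch.randn(n_voxels, 4), dim=1) for _ in range(20)]
+
+    entropy_map = get_entropy_uncertainty(scores, toy_size)
+    wt_unc, tc_unc, et_unc = get_variation_uncertainty(scores, toy_size)
+    print(entropy_map.shape, wt_unc.shape, tc_unc.max(), et_unc.max())
+
+    # brats_normalize stretches a map to [0, 100] given the min/max observed
+    # inside the brain mask; here the whole toy map plays that role.
+    normalized = brats_normalize(entropy_map, max_unc=int(entropy_map.max()),
+                                 min_unc=int(entropy_map[entropy_map > 0].min()))
+    print(normalized.min(), normalized.max())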