diff --git a/src/uncertainty/uncertainty.py b/src/uncertainty/uncertainty.py
from typing import Tuple, List
import torch
import numpy as np
from scipy import stats
import os
from tqdm import tqdm

from src.dataset.utils import nifi_volume


def get_entropy_uncertainty(prediction_score_vectors: List[torch.Tensor], matrix_size: Tuple) -> np.ndarray:
    """
    Compute a single global uncertainty map as the entropy of the mean class probabilities
    predicted with test-time dropout (TTD).

    :param prediction_score_vectors: list of tensors with the predicted score per label, one tensor per TTD pass
    :param matrix_size: spatial shape of the output volume
    :return: the global uncertainty map, scaled by 100 and cast to uint8
    """
    prediction_score_vectors = torch.stack(tuple(prediction_score_vectors))
    mean = np.mean(prediction_score_vectors.cpu().numpy(), axis=0)
    entropy = stats.entropy(mean, axis=1).reshape(matrix_size) * 100
    return entropy.astype(np.uint8)
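
# Usage sketch (assumed setup: `scores` is a list of softmax tensors, one per
# test-time-dropout pass, each of shape (num_voxels, num_classes) for a 240x240x155 volume):
#   entropy_map = get_entropy_uncertainty(scores, matrix_size=(240, 240, 155))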


def get_variation_uncertainty(prediction_score_vectors: List[torch.Tensor], matrix_size: Tuple) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute uncertainty as the variance of the TTD predictions for the evaluation regions WT, TC and ET.

    :param prediction_score_vectors: list of tensors with the predicted score per label, one tensor per TTD pass
    :param matrix_size: spatial shape of the output volumes
    :return: whole tumor (WT), tumor core (TC) and enhancing tumor (ET) uncertainty maps, scaled by 100 and cast to uint8
    """
    prediction_score_vectors = torch.stack(tuple(prediction_score_vectors))

    wt_var = np.var(np.sum(prediction_score_vectors[:, :, 1:].cpu().numpy(), axis=2), axis=0).reshape(matrix_size) * 100
    tc_var = np.var(np.sum(prediction_score_vectors[:, :, [1, 3]].cpu().numpy(), axis=2), axis=0).reshape(matrix_size) * 100
    et_var = np.var(prediction_score_vectors[:, :, 3].cpu().numpy(), axis=0).reshape(matrix_size) * 100

    return wt_var.astype(np.uint8), tc_var.astype(np.uint8), et_var.astype(np.uint8)
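
# The channel indices above assume BraTS-style labels with label 4 remapped to 3
# (0=background, 1=necrotic/non-enhancing core, 2=edema, 3=enhancing tumor), i.e.
# WT = channels 1+2+3, TC = channels 1+3, ET = channel 3. Usage sketch:
#   wt_unc, tc_unc, et_unc = get_variation_uncertainty(scores, matrix_size=(240, 240, 155))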


def brats_normalize(uncertainty_map: np.ndarray, max_unc: float, min_unc: float) -> np.ndarray:
    """
    Linearly rescale non-zero uncertainty values from [min_unc, max_unc] to [0, 100]; zero voxels stay zero.
    """
    minimum = 0
    maximum = 100
    step = (maximum - minimum) / (max_unc - min_unc)
    vfunc = np.vectorize(lambda x: (x - min_unc) * step if x != 0 else 0)
    return vfunc(uncertainty_map).astype(np.uint8)
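
# Worked example with hypothetical bounds: for min_unc=2 and max_unc=52, step = 100 / 50 = 2,
# so an uncertainty of 27 maps to (27 - 2) * 2 = 50, 52 maps to 100, and zero voxels stay 0.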


def compute_normalization(input_dir, output_dir, ground_truth_path):
    """
    Normalize all uncertainty maps in input_dir to the [0, 100] range, using the global minimum and
    maximum uncertainty found inside the brain mask across all patients, and write the results to
    output_dir. Files without "unc" in their name are copied through unchanged.
    """
    file_list = sorted([file for file in os.listdir(input_dir) if "unc" in file and "nii.gz" in file])
    file_list_all = sorted([file for file in os.listdir(input_dir) if "nii.gz" in file])

    max_uncertainty = 0
    min_uncertainty = 10000

    for uncertainty_map in tqdm(file_list, total=len(file_list), desc="Getting min and max"):

        # Load the FLAIR volume and build a brain mask from its non-zero voxels
        patient_name = uncertainty_map.split(".")[0].split("_unc")[0]
        path_gt = os.path.join(ground_truth_path, patient_name, f"{patient_name}_flair.nii.gz")
        flair = nifi_volume.load_nifi_volume(path_gt, normalize=False)
        brain_mask = np.zeros(flair.shape, dtype=np.float32)
        brain_mask[flair > 0] = 1

        # Load the uncertainty map and update the global min/max inside the brain
        path = os.path.join(input_dir, uncertainty_map)
        unc_map, _ = nifi_volume.load_nifi_volume_return_nib(path, normalize=False)

        tmp_max = np.max(unc_map[brain_mask == 1])
        tmp_min = np.min(unc_map[brain_mask == 1])

        if tmp_max > max_uncertainty:
            max_uncertainty = tmp_max

        if tmp_min < min_uncertainty:
            min_uncertainty = tmp_min

    for uncertainty_map_path in tqdm(file_list_all, total=len(file_list_all), desc="Normalizing.."):

        path = os.path.join(input_dir, uncertainty_map_path)
        output_path = os.path.join(output_dir, uncertainty_map_path)

        unc_map, nib_data = nifi_volume.load_nifi_volume_return_nib(path, normalize=False)

        if "unc" in uncertainty_map_path:
            uncertainty_map_normalized = brats_normalize(unc_map, max_unc=max_uncertainty, min_unc=min_uncertainty)
            print(f"Saving to: {output_path}")
            nifi_volume.save_segmask_as_nifi_volume(uncertainty_map_normalized, nib_data.affine, output_path)
        else:
            nifi_volume.save_segmask_as_nifi_volume(unc_map, nib_data.affine, output_path)
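

if __name__ == "__main__":
    # Minimal usage sketch; the three paths are illustrative placeholders and should point to
    # the folder with predicted .nii.gz maps, the output folder, and the ground-truth root
    # containing one folder per patient with its FLAIR volume.
    compute_normalization(input_dir="predictions/",
                          output_dir="predictions_normalized/",
                          ground_truth_path="datasets/brats_train/")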