Switch to unified view

a b/src/dataset/brats_labels.py
1
import numpy as np
2
import torch
3
4
def brats_segmentation_regions() -> dict:
5
    return {"ET_brats": 4, "ET": 3,  "NCR-NET": 1, "ED": 2}
6
7
8
def get_ncr_net(segmentation_map: np.ndarray) -> np.ndarray:
9
    regions = brats_segmentation_regions()
10
    copied_segmentation = _copy_input(segmentation_map)
11
    copied_segmentation[copied_segmentation != regions["NCR-NET"]] = 0
12
    return copied_segmentation.astype(np.uint8)
13
14
15
def get_ed(segmentation_map: np.ndarray) -> np.ndarray:
16
    regions = brats_segmentation_regions()
17
    copied_segmentation = _copy_input(segmentation_map)
18
    copied_segmentation[copied_segmentation != regions["ED"]] = 0
19
    return copied_segmentation.astype(np.uint8)
20
21
22
def get_et(segmentation_map: np.ndarray) -> np.ndarray:
23
    """
24
    ET : enhancing tumors is label 3 in the code, 4 as brats
25
    :param segmentation_map:
26
    :return: only label for ET
27
    """
28
    regions = brats_segmentation_regions()
29
    copied_segmentation = _copy_input(segmentation_map)
30
    unique_values = np.unique(copied_segmentation)
31
    if max(unique_values) == 3:
32
        copied_segmentation[copied_segmentation != regions["ET"]] = 0
33
    else:
34
        copied_segmentation[copied_segmentation != regions["ET_brats"]] = 0
35
36
    copied_segmentation[copied_segmentation != 0] = 1
37
    return copied_segmentation.astype(np.uint8)
38
39
40
def get_wt(segmentation_map: np.ndarray) -> np.ndarray:
41
    """ WT : entails all regions 4 (ET, NCR, ED ) """
42
    copied_segmentation = _copy_input(segmentation_map)
43
    copied_segmentation[copied_segmentation != 0] = 1
44
    return copied_segmentation.astype(np.uint8)
45
46
47
def get_tc(segmentation_map: np.ndarray) -> np.ndarray:
48
    """ TC : tumor core entails the ET and NCR/NET labels 4 and 1 """
49
    regions = brats_segmentation_regions()
50
51
    copied_segmentation = _copy_input(segmentation_map)
52
    copied_segmentation[copied_segmentation == regions["ED"]] = 0  # remove edema
53
    copied_segmentation[copied_segmentation > 0] = 1
54
55
    return copied_segmentation.astype(np.uint8)
56
57
def convert_from_brats_labels(segmentation_map: np.ndarray) -> np.ndarray:
58
    """Method to convert brats labels as models need consecutive values"""
59
    regions = brats_segmentation_regions()
60
    segmentation_map[segmentation_map == regions["ET_brats"]] = regions["ET"]
61
    return segmentation_map
62
63
def convert_to_brats_labels(segmentation_map: np.ndarray) -> np.ndarray:
64
    """Method to convert recover brats labels encoding"""
65
    regions = brats_segmentation_regions()
66
    segmentation_map[segmentation_map == regions["ET"]] = regions["ET_brats"]
67
    return segmentation_map
68
69
70
def _copy_input(input):
71
    if torch.is_tensor(input):
72
        return input.detach().clone()
73
    else:
74
        return input.copy()