a b/adpkd_segmentation/datasets/masks.py
1
import numpy as np
2
3
# Integer label codes matched by the mask transforms in this module.
BACKGROUND_INT = 0
L_KIDNEY_INT = 128
R_KIDNEY_INT = 191

# The same codes rescaled to [0, 1]: code / 255 rounded to 7 decimals
# (128 / 255 ~ 0.5019608, 191 / 255 ~ 0.7490196).
BACKGROUND = 0.0
L_KIDNEY = 0.5019608
R_KIDNEY = 0.7490196
10
11
12
class SingleChannelMaskNumpy:
    """Sets 1 for kidneys (left or right), 0 otherwise."""

    def __call__(self, label):
        """
        Args:
            label, (1, H, W) uint8 numpy array of integer label codes
                (BACKGROUND_INT, L_KIDNEY_INT, R_KIDNEY_INT)

        Returns:
            numpy array with the same shape as ``label`` ((1, H, W)), uint8
            binary mask: 1 where ``label`` equals either kidney code,
            0 everywhere else (background or any other code).
        """
        # Union of the two boolean kidney masks; logical_or states the
        # "either kidney" intent (bitwise_or on bools gives the same result).
        kidney = np.logical_or(label == R_KIDNEY_INT, label == L_KIDNEY_INT)
        return kidney.astype(np.uint8)
24
25
26
class TwoChannelsMaskNumpy:
    """
    Split the label into one binary channel per kidney.

    Channel 0 marks the right kidney and channel 1 the left kidney. In each
    channel the kidney's pixels are 1 and every other pixel (background,
    the other kidney, or any other code) is 0.
    """

    def __call__(self, label):
        """
        Args:
            label, (1, H, W) uint8 numpy array

        Returns:
            numpy array, (2, H, W) uint8 mask, channels ordered
            (right kidney, left kidney)
        """
        # Order here fixes the channel order of the output.
        kidney_codes = (R_KIDNEY_INT, L_KIDNEY_INT)
        channels = [label == code for code in kidney_codes]
        return np.concatenate(channels, axis=0).astype(np.uint8)
45
46
47
class ThreeChannelMaskNumpy:
    """One channel for each of the 3 classes. Background last.

    Channel order is (right kidney, left kidney, background).
    """

    def __call__(self, label):
        """
        Args:
            label, (1, H, W) uint8 numpy array of integer label codes
                (BACKGROUND_INT, R_KIDNEY_INT, L_KIDNEY_INT). The codes are
                compared with ``==``, so a float label rescaled to [0, 1]
                (BACKGROUND / L_KIDNEY / R_KIDNEY values) would match only
                the background channel and yield empty kidney channels.

        Returns:
            numpy array, (3, H, W) uint8 one-hot encoded mask, channels
            ordered (right kidney, left kidney, background). A pixel whose
            code is none of the three gets 0 in every channel.
        """
        background = label == BACKGROUND_INT
        r_kidney = label == R_KIDNEY_INT
        l_kidney = label == L_KIDNEY_INT
        # Background goes last; this order defines the output channel order.
        mask = np.concatenate([r_kidney, l_kidney, background], axis=0).astype(
            np.uint8
        )
        return mask