a b/flair-segmentation/data.py
1
from __future__ import print_function
2
3
import os
4
5
import numpy as np
6
from skimage.io import imread
7
from skimage.transform import rescale
8
from skimage.transform import rotate
9
10
# Expected spatial size (rows, cols) of every slice loaded by this module.
image_rows, image_cols = 256, 256

# Neighbouring slices per sample: 1 uses only the current slice; 3 adds the
# previous and next slice as additional channels.
channels = 1
# Modalities per slice: 3 uses pre, FLAIR and post; 1 uses only FLAIR.
modalities = 3
15
16
17
def load_data(path):
    """
    Load every image slice in a folder together with its mask.

    Assumes filenames in given path to be in the following format as defined in `preprocessing3D.m`:
    for images: <case_id>_<slice_number>.tif
    for masks: <case_id>_<slice_number>_mask.tif

    Every file whose name does not contain "mask" is treated as an image and
    must have its ``<name>_mask.tif`` counterpart in the same folder. Files
    are processed in sorted filename order.

        Args:
            path: string to the folder with images

        Returns:
            np.ndarray: float32 array of images with shape
                (count, image_rows, image_cols, channels * modalities)
            np.ndarray: float32 array of masks divided by 255, with shape
                (count, image_rows, image_cols, 1)
            np.chararray: array of corresponding images' filenames without
                extensions (stored as bytes, at most 64 characters)
    """
    # Sorted so the returned order is deterministic (os.listdir order is not).
    image_names = sorted(name for name in os.listdir(path) if "mask" not in name)
    # Size the buffers from the images actually found. ``len(listdir) / 2`` is
    # a float on Python 3 (rejected as an array shape) and miscounts whenever
    # the folder does not hold exactly one mask per image.
    total_count = len(image_names)
    images = np.ndarray(
        (total_count, image_rows, image_cols, channels * modalities), dtype=np.uint8
    )
    masks = np.ndarray((total_count, image_rows, image_cols), dtype=np.uint8)
    names = np.chararray(total_count, itemsize=64)

    for i, image_name in enumerate(image_names):
        # Parse the str filename rather than names[i]: chararray elements come
        # back as bytes on Python 3 and cannot be split on a str separator.
        base_name = os.path.splitext(image_name)[0]
        names[i] = base_name
        patient_id, _, slice_token = base_name.rpartition("_")
        slice_number = int(slice_token)

        # ``as_gray`` is the current spelling; ``as_grey`` was removed from
        # scikit-image.
        # NOTE(review): if these .tif files are multi-channel, as_gray yields
        # floats in [0, 1] that the uint8 buffer truncates — confirm inputs
        # are already single-channel when modalities == 1.
        img = imread(os.path.join(path, image_name), as_gray=(modalities == 1))
        img_mask = imread(os.path.join(path, base_name + "_mask.tif"), as_gray=True)

        # Normalise to (rows, cols, k): grayscale reads are 2-D, and neighbour
        # slices from read_slice carry an extra trailing axis.
        img = np.reshape(img, (image_rows, image_cols, -1))
        if channels > 1:
            img_prev = read_slice(path, patient_id, slice_number - 1)
            img_next = read_slice(path, patient_id, slice_number + 1)
            img = np.dstack(
                (
                    np.reshape(img_prev, (image_rows, image_cols, -1)),
                    img,
                    np.reshape(img_next, (image_rows, image_cols, -1)),
                )
            )

        images[i] = img
        masks[i] = img_mask

    images = images.astype("float32")
    masks = masks[..., np.newaxis]
    masks = masks.astype("float32")
    masks /= 255.

    return images, masks, names
74
75
76
def oversample(images, masks, augment=False):
    """
    Append two extra samples for every slice whose mask reaches 1.

    A slice is oversampled when ``np.max(mask) >= 1``. ``load_data`` divides
    masks by 255, so this selects slices containing at least one
    full-intensity mask pixel. Each selected slice contributes two extra
    samples: two exact copies, or, with ``augment=True``, one randomly rotated
    and one randomly rescaled copy.

        Args:
            images: np.ndarray of images, first axis indexes slices
            masks: np.ndarray of masks aligned with ``images``
            augment: if True, append augmented copies instead of exact ones

        Returns:
            np.ndarray: original images followed by the oversampled images
            np.ndarray: original masks followed by the oversampled masks
    """
    images = np.asarray(images)
    masks = np.asarray(masks)

    images_o = []
    masks_o = []
    for image, mask in zip(images, masks):
        if np.max(mask) < 1:
            continue

        if augment:
            image_a, mask_a = augmentation_rotate(image, mask)
            images_o.append(image_a)
            masks_o.append(mask_a)
            image_a, mask_a = augmentation_scale(image, mask)
            images_o.append(image_a)
            masks_o.append(mask_a)
        else:
            images_o.extend((image, image))
            masks_o.extend((mask, mask))

    if not images_o:
        # Nothing qualified. np.vstack cannot stack an empty (0,)-shaped array
        # onto the input, so return copies of the inputs unchanged.
        return images.copy(), masks.copy()

    # Cast back to the input dtype. The augmentation helpers may return
    # float64 (preserve_range), which would otherwise upcast the whole result.
    images_o = np.asarray(images_o, dtype=images.dtype)
    masks_o = np.asarray(masks_o, dtype=masks.dtype)

    return np.vstack((images, images_o)), np.vstack((masks, masks_o))
111
112
113
def read_slice(path, patient_id, slice):
    """
    Read the neighbouring slice ``<patient_id>_<slice>.tif`` from ``path``.

    A slice file that does not exist (e.g. past either end of the volume) is
    replaced by zeros, so edge slices still get a full neighbour stack.

        Args:
            path: folder containing the slice images
            patient_id: case identifier prefix of the filename
            slice: slice number (shadows the builtin; name kept for callers)

        Returns:
            np.ndarray: the image read (2-D when modalities == 1, otherwise as
                stored in the file, with a zero fallback of ``modalities``
                channels) with one extra trailing axis appended
    """
    img_name = patient_id + "_" + str(slice) + ".tif"
    img_path = os.path.join(path, img_name)

    # Only a missing file means "no neighbour". Any other read failure (corrupt
    # file, unsupported imread argument, ...) now propagates instead of being
    # silently replaced by an all-zero slice.
    if os.path.isfile(img_path):
        img = imread(img_path, as_gray=(modalities == 1))
    elif modalities == 1:
        img = np.zeros((image_rows, image_cols))
    else:
        # One zero channel per modality so the fallback stacks like a real read.
        img = np.zeros((image_rows, image_cols, modalities))

    return img[..., np.newaxis]
124
125
126
def augmentation_rotate(img, img_mask):
    """
    Rotate an image and its mask by the same random angle.

    The angle magnitude is drawn uniformly from [5, 15) and its sign is chosen
    at random. The image uses cubic interpolation (order=3); the mask uses
    nearest-neighbour (order=0) so label values are not blended. Both keep the
    input size (resize=False) and value range (preserve_range=True).

        Returns:
            tuple: (rotated image, rotated mask)
    """
    magnitude = np.random.uniform(5.0, 15.0)
    sign = np.random.choice([-1.0, 1.0], 1)[0]
    angle = magnitude * sign

    rotated_img = rotate(img, angle, resize=False, order=3, preserve_range=True)
    rotated_mask = rotate(img_mask, angle, resize=False, order=0, preserve_range=True)
    return rotated_img, rotated_mask
133
134
135
def augmentation_scale(img, img_mask):
    """
    Randomly zoom an image and its mask by 4-8% and restore the target size.

    The zoom factor is ``1 ± u`` with ``u`` uniform in [0.04, 0.08). A factor
    above 1 is center-cropped back to image_rows x image_cols; otherwise the
    result is zero-padded back to image_rows. The image uses cubic
    interpolation (order=3); the mask uses nearest-neighbour (order=0).

        Returns:
            tuple: (rescaled image, rescaled mask)
    """
    magnitude = np.random.uniform(0.04, 0.08)
    direction = np.random.choice([-1.0, 1.0], 1)[0]
    factor = 1.0 + magnitude * direction

    # NOTE(review): a scalar factor with no channel argument may also rescale
    # a trailing channel axis in some scikit-image versions — confirm.
    scaled_img = rescale(img, factor, order=3, preserve_range=True)
    scaled_mask = rescale(img_mask, factor, order=0, preserve_range=True)

    if factor > 1:
        # Zoomed in: trim back to the target size around the centre.
        return (
            center_crop(scaled_img, image_rows, image_cols),
            center_crop(scaled_mask, image_rows, image_cols),
        )
    # Zoomed out: surround with zeros back up to the target size.
    return zeros_pad(scaled_img, image_rows), zeros_pad(scaled_mask, image_rows)
148
149
150
def center_crop(img, cropx, cropy):
    """
    Return the central ``cropy`` x ``cropx`` window of ``img``.

    Axis 0 is treated as rows (cropped to ``cropy``) and axis 1 as columns
    (cropped to ``cropx``). Any trailing axes are kept whole.
    """
    height, width = img.shape[0], img.shape[1]
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    row_window = slice(top, top + cropy)
    col_window = slice(left, left + cropx)
    return img[row_window, col_window]
154
155
156
def zeros_pad(img, size):
    """
    Zero-pad the spatial axes of ``img`` up to ``size`` along each axis.

    The first two axes (rows, cols) are each padded independently so the
    result is ``size`` x ``size``. The previous version padded both axes by
    the row deficit, which left non-square inputs non-square. Any trailing
    (channel) axes are left unpadded. When the deficit is odd, Python's
    ``round`` decides which side gets the extra pixel (unchanged from before).

        Args:
            img: np.ndarray with at least one axis, no larger than ``size``
                along the padded axes
            size: target length of each padded axis

        Returns:
            np.ndarray: the zero-padded array
    """

    def _split(extent):
        # Distribute the deficit on both sides of one axis.
        before = int(round((size - extent) / 2.0))
        return before, size - extent - before

    n_spatial = min(img.ndim, 2)
    pad_width = [_split(img.shape[axis]) for axis in range(n_spatial)]
    pad_width += [(0, 0)] * (img.ndim - n_spatial)
    return np.pad(img, pad_width, mode="constant")