|
a |
|
b/flair-segmentation/data.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
from skimage.io import imread |
|
|
7 |
from skimage.transform import rescale |
|
|
8 |
from skimage.transform import rotate |
|
|
9 |
|
|
|
10 |
# Height (rows) and width (cols) of the per-sample arrays allocated in
# load_data; every slice read from disk is assigned into arrays of this size.
image_rows = image_cols = 256

# Neighbouring slices stacked as channels per sample: 1 uses only the current
# slice; 3 also takes the previous and next slice as additional channels.
channels = 1

# Modalities per slice: 3 uses the pre, FLAIR and post modalities; 1 uses only
# FLAIR.
modalities = 3


def load_data(path):
    """
    Load every image/mask pair found in ``path``.

    Assumes filenames in given path to be in the following format as defined in `preprocessing3D.m`:
    for images: <case_id>_<slice_number>.tif
    for masks: <case_id>_<slice_number>_mask.tif

    Every file whose name does not contain "mask" is treated as an image and is
    paired with the file of the same stem plus "_mask.tif". Array sizes come from
    the module-level ``image_rows``, ``image_cols``, ``channels`` and
    ``modalities`` settings.

    Args:
        path: string to the folder with images

    Returns:
        np.ndarray: float32 array of images, shape
            (N, image_rows, image_cols, channels * modalities)
        np.ndarray: float32 array of masks divided by 255, shape
            (N, image_rows, image_cols, 1)
        np.chararray: array of corresponding images' filenames without extensions
    """
    # Count the image files directly: this is an integer on Python 3 (unlike
    # len(listing) / 2) and does not assume exactly half the listing are masks.
    image_names = [name for name in os.listdir(path) if "mask" not in name]
    total_count = len(image_names)
    images = np.ndarray(
        (total_count, image_rows, image_cols, channels * modalities), dtype=np.uint8
    )
    masks = np.ndarray((total_count, image_rows, image_cols), dtype=np.uint8)
    names = np.chararray(total_count, itemsize=64)

    for i, image_name in enumerate(image_names):
        # Parse from the str name: names[i] holds bytes on Python 3, which
        # cannot be split on a str separator.
        name = image_name.split(".")[0]
        names[i] = name
        slice_number = int(name.split("_")[-1])
        patient_id = "_".join(name.split("_")[:-1])

        image_mask_name = name + "_mask.tif"
        # `as_gray` is the current scikit-image keyword (formerly `as_grey`).
        img = imread(os.path.join(path, image_name), as_gray=(modalities == 1))
        img_mask = imread(os.path.join(path, image_mask_name), as_gray=True)

        if channels > 1:
            # NOTE(review): with modalities > 1 each stacked part keeps an extra
            # trailing axis, so the result will not fit images[i]; this branch
            # appears to support only modalities == 1 -- confirm.
            img_prev = read_slice(path, patient_id, slice_number - 1)
            img_next = read_slice(path, patient_id, slice_number + 1)

            img = np.dstack((img_prev, img[..., np.newaxis], img_next))

        elif modalities == 1:
            # A 2-D grayscale slice needs a trailing channel axis to match
            # images[i], whose shape is (image_rows, image_cols, 1).
            img = img[..., np.newaxis]

        images[i] = img
        masks[i] = img_mask

    images = images.astype("float32")
    # Masks gain a trailing channel axis and are scaled from 0..255 to 0..1.
    masks = masks[..., np.newaxis]
    masks = masks.astype("float32")
    masks /= 255.0

    return images, masks, names


def oversample(images, masks, augment=False):
    """
    Append extra copies of every slice whose mask contains a positive value.

    A slice counts as positive when ``np.max(mask) >= 1``. Without augmentation
    each positive slice is appended twice unchanged. With ``augment=True`` it is
    instead appended once rotated (``augmentation_rotate``) and once rescaled
    (``augmentation_scale``).

    Args:
        images: np.ndarray of images indexed along the first axis
        masks: np.ndarray of masks aligned with ``images``
        augment: if True, append augmented copies instead of exact duplicates

    Returns:
        np.ndarray: the original images followed by the oversampled ones
        np.ndarray: the original masks followed by the oversampled ones
    """
    images_o = []
    masks_o = []
    for image, mask in zip(images, masks):
        if np.max(mask) < 1:
            continue

        if augment:
            for augmentation in (augmentation_rotate, augmentation_scale):
                image_a, mask_a = augmentation(image, mask)
                images_o.append(image_a)
                masks_o.append(mask_a)
        else:
            images_o.extend([image, image])
            masks_o.extend([mask, mask])

    # With no positive slice there is nothing to append. np.vstack would fail
    # on the (0,)-shaped array np.array([]) yields, so return copies of the inputs.
    if not images_o:
        return np.copy(images), np.copy(masks)

    images_o = np.array(images_o)
    masks_o = np.array(masks_o)

    return np.vstack((images, images_o)), np.vstack((masks, masks_o))


def read_slice(path, patient_id, slice):
    """
    Read slice ``<patient_id>_<slice>.tif`` from ``path`` with a new trailing axis.

    ``load_data`` uses this to fetch the previous/next slice of a case. A slice
    that cannot be opened (e.g. beyond either end of the case) is replaced by zeros.

    Args:
        path: folder containing the slices
        patient_id: case identifier that prefixes the filename
        slice: slice number (shadows the builtin; kept for caller compatibility)

    Returns:
        np.ndarray: the slice (or zeros) with a trailing axis appended
    """
    img_name = patient_id + "_" + str(slice) + ".tif"
    img_path = os.path.join(path, img_name)

    try:
        img = imread(img_path, as_gray=(modalities == 1))
    except (IOError, OSError):
        # Only a missing/unreadable file is expected here. Other errors (bad
        # arguments, shape problems) now propagate instead of becoming zeros.
        # NOTE(review): the fallback is 2-D, matching only single-modality reads.
        img = np.zeros((image_rows, image_cols))

    return img[..., np.newaxis]


def augmentation_rotate(img, img_mask):
    """
    Rotate an image and its mask by one shared random angle.

    The magnitude is drawn uniformly from [5, 15). The sign is picked from
    {-1, +1}. Both draws use NumPy's global random state, and the result is
    passed as the angle to skimage's ``rotate``. The frame size is kept
    (``resize=False``) and ``preserve_range=True`` keeps the input value range.
    The image is interpolated with order 3, the mask with order 0 so its label
    values are not blended.

    Returns:
        tuple: (rotated image, rotated mask)
    """
    magnitude = np.random.uniform(5.0, 15.0)
    sign = np.random.choice([-1.0, 1.0], 1)[0]
    angle = magnitude * sign

    rotated_img = rotate(img, angle, resize=False, order=3, preserve_range=True)
    rotated_mask = rotate(img_mask, angle, resize=False, order=0, preserve_range=True)
    return rotated_img, rotated_mask


def augmentation_scale(img, img_mask):
    """
    Rescale an image and its mask by one shared random factor, keeping the frame size.

    The factor is 1 +/- U(0.04, 0.08), drawn from NumPy's global random state.
    Enlarged arrays are center-cropped back to ``image_rows`` x ``image_cols``.
    Shrunk arrays are zero-padded back to ``image_rows``. ``zeros_pad`` pads both
    spatial axes by the same amount, so a square frame is assumed.

    Args:
        img: image array, interpolated with order 3
        img_mask: mask array, interpolated with order 0 so labels are not blended

    Returns:
        tuple: (rescaled image, rescaled mask)
    """
    scale = 1.0 + np.random.uniform(0.04, 0.08) * np.random.choice([-1.0, 1.0], 1)[0]

    # NOTE(review): depending on the scikit-image version, rescale may also
    # resample the last axis of 3-D input. For 1-3 channels and |scale - 1| <= 0.08
    # the rounded channel count stays the same -- confirm on upgrade.
    img = rescale(img, scale, order=3, preserve_range=True)
    img_mask = rescale(img_mask, scale, order=0, preserve_range=True)
    if scale > 1:
        # center_crop takes the width (cropx, axis 1) before the height
        # (cropy, axis 0).
        img = center_crop(img, image_cols, image_rows)
        img_mask = center_crop(img_mask, image_cols, image_rows)
    else:
        img = zeros_pad(img, image_rows)
        img_mask = zeros_pad(img_mask, image_rows)

    return img, img_mask


def center_crop(img, cropx, cropy):
    """
    Return the central ``cropy`` x ``cropx`` window of ``img``.

    There is no bounds check: the window is assumed to fit inside ``img``.

    Args:
        img: array whose first two axes are rows and columns; further axes are kept whole
        cropx: window width, taken along axis 1
        cropy: window height, taken along axis 0

    Returns:
        np.ndarray: the cropped window (a basic slice of ``img``)
    """
    height = img.shape[0]
    width = img.shape[1]
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2

    row_window = slice(top, top + cropy)
    col_window = slice(left, left + cropx)
    return img[row_window, col_window]


def zeros_pad(img, size):
    """
    Zero-pad ``img`` so its leading axes reach length ``size``.

    The padding amount comes from ``img.shape[0]`` alone. The "before" half is
    ``round(missing / 2)`` and the remainder goes after. For input with more than
    two dimensions, the first two axes are padded and the third is left
    untouched. Otherwise the same (before, after) padding is applied to every
    axis, so a square input is assumed.

    Args:
        img: array to pad
        size: target length of the padded axes

    Returns:
        np.ndarray: the zero-padded array
    """
    missing = size - img.shape[0]
    before = int(round(missing / 2.0))
    after = missing - before

    if img.ndim <= 2:
        return np.pad(img, (before, after), mode="constant")

    spatial = [(before, after), (before, after)]
    return np.pad(img, tuple(spatial + [(0, 0)]), mode="constant")