data_generator.py
"""
Utilities for a real-time, multi-threaded data generator.
"""

import scipy.ndimage   # import the ndimage submodule explicitly (plain `import scipy` may not expose it)
import numpy as np
from tensorflow.keras.utils import Sequence, to_categorical


class CustomDataGenerator(Sequence):

    def __init__(self, hdf5_file, brain_idx, batch_size=16, view="axial", mode='train', horizontal_flip=False,
                 vertical_flip=False, rotation_range=0, zoom_range=0., shuffle=True):
        """
        Custom data generator based on the Keras Sequence class.
        This implementation enables multiprocessing and on-the-fly data augmentation,
        which speeds up training, especially for brain tumor segmentation,
        where data processing is time-consuming.

        Parameters
        ----------
        hdf5_file : file.File
            An opened hdf5 file that contains all the data.
        brain_idx : array
            The brain indexes corresponding to a specific fold. All of these
            brain indexes will be used for training, and the ones which are
            not in 'brain_idx' will be used for validation.
        batch_size : int
            The number of input/output arrays that will be generated each
            time. The default is 16.
        view : str
            'axial', 'sagittal' or 'coronal'. The generator will extract
            2D slices and perform normalization with respect to the chosen view.
            The default is 'axial'.
        mode : str
            Prepare the DataGenerator for the 'train' or 'validation' phase.
            The default is 'train'.
        horizontal_flip : bool
            Whether to use horizontal flip for data augmentation. The default is False.
        vertical_flip : bool
            Whether to use vertical flip for data augmentation. The default is False.
        rotation_range : float
            Random rotation range (in degrees) for data augmentation. The default is 0.
        zoom_range : float
            Random zoom range for data augmentation. The default is 0.
        shuffle : bool
            Whether to shuffle the data. The default is True. Note that if
            mode='validation' the data will not be shuffled.

        """

        self.data_storage = hdf5_file.root.data
        self.truth_storage = hdf5_file.root.truth

        total_brains = self.data_storage.shape[0]
        self.brain_idx = self.get_brain_idx(brain_idx, mode, total_brains)
        self.batch_size = batch_size

        if view == 'axial':
            self.view_axes = (0, 1, 2, 3)
        elif view == 'sagittal':
            self.view_axes = (2, 1, 0, 3)
        elif view == 'coronal':
            self.view_axes = (1, 2, 0, 3)
        else:
            raise ValueError('unknown input view => {}'.format(view))

        self.mode = mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rotation_range = rotation_range
        self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        self.shuffle = shuffle
        self.data_shape = tuple(np.array(self.data_storage.shape[1:])[np.array(self.view_axes)])

        print('Using {} out of {} brains'.format(len(self.brain_idx), total_brains), end=' ')
        print('({} out of {} 2D slices)'.format(len(self.brain_idx) * self.data_shape[0], total_brains * self.data_shape[0]))
        print('the generated data shape in "{}" view: {}'.format(view, str(self.data_shape[1:])))
        print('-----' * 10)

        self.on_epoch_end()


    @staticmethod
    def get_brain_idx(brain_idx, mode, total_brains):
        """
        Get the brain indexes that will be used by the generator.
        If mode=='train' => the original indexes will be used (because we built these
        npy files based on training indexes in 'prepare_data.py' for k-fold, remember? :)
        If mode=='validation' => the indexes which are not in 'brain_idx' will
        be used.
        """
        if mode == 'validation':
            brain_idx = np.array([i for i in np.arange(total_brains) if i not in brain_idx])
            print('DataGenerator is preparing for validation mode ...')
        elif mode == 'train':
            print('DataGenerator is preparing for training mode ...')
        else:
            raise ValueError('unknown "{}" mode'.format(mode))

        return brain_idx
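
    # For example (hypothetical numbers): with total_brains = 10 and a fold whose
    # brain_idx is [0, 1, ..., 7], mode='train' keeps those 8 brains, while
    # mode='validation' returns the remaining indexes [8, 9].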


    def __len__(self):
        """
        Number of batches per epoch.
        """
        return int(np.floor(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):

        # Generate indexes of the batch
        idx = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Generate data
        X_batch, Y_batch = self.data_load_and_preprocess(idx)

        return X_batch, Y_batch
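
    # Shapes (as built by data_load_and_preprocess below): X_batch is
    # (batch_size, H, W, 4) with the four normalized modalities, and Y_batch is
    # (batch_size, H, W, 4) one-hot labels, where (H, W) is the in-plane slice
    # size of the chosen view.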

    def on_epoch_end(self):
        """
        Updates indexes after each epoch
        """
        tmp = []
        for i in self.brain_idx:
            for j in range(self.data_shape[0]):
                tmp.append((i, j))
        self.indexes = tmp

        if self.mode == 'train' and self.shuffle:
            np.random.shuffle(self.indexes)
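
    # Each entry of self.indexes is a (brain_number, slice_number) pair, so one epoch
    # covers every 2D slice of every selected brain; shuffling mixes slices across brains.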

    def data_load_and_preprocess(self, idx):
        """
        Generates data containing batch_size samples
        """
        slice_batch = []
        label_batch = []

        # Generate data
        for i in idx:
            brain_number = i[0]
            slice_number = i[1]
            slice_, label_ = self.read_data(brain_number, slice_number)
            slice_ = self.normalize_modalities(slice_)
            # stack image and label along the channel axis so the same random
            # transform is applied to both
            slice_and_label = np.concatenate((slice_, label_), axis=-1)
            params = self.get_random_transform()
            slice_and_label = self.apply_transform(slice_and_label, params)
            slice_ = slice_and_label[..., :4]     # channels 0-3: modalities
            label_ = slice_and_label[..., 4]      # channel 4: segmentation label
            label_ = to_categorical(label_, 4)

            slice_batch.append(slice_)
            label_batch.append(label_)

        return np.array(slice_batch), np.array(label_batch)

    def read_data(self, brain_number, slice_number):
        """
        Reads data from table with respect to the 'view'
        """
        slice_ = self.data_storage[brain_number].transpose(self.view_axes)[slice_number]
        label_ = self.truth_storage[brain_number].transpose(self.view_axes[:3])[slice_number]
        label_ = np.expand_dims(label_, axis=-1)

        return slice_, label_

    def normalize_slice(self, slice):
        """
        Clips the top and bottom 1% of intensities and performs
        z-score normalization on the input 2D slice.
        """
        b = np.percentile(slice, 99)
        t = np.percentile(slice, 1)
        slice = np.clip(slice, t, b)
        if np.std(slice) == 0:
            return slice
        else:
            slice = (slice - np.mean(slice)) / np.std(slice)
            return slice
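
    # Slices with zero variance after clipping (e.g. all-background slices) are
    # returned without standardization to avoid a division by zero.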

    def normalize_modalities(self, Slice):
        """
        Performs normalization on each modality of the input slice.
        """
        normalized_slices = np.zeros_like(Slice).astype(np.float32)
        for slice_ix in range(4):
            normalized_slices[..., slice_ix] = self.normalize_slice(Slice[..., slice_ix])

        return normalized_slices

    def flip_axis(self, x, axis):
        """
        Reverses the order of elements of x along the given axis.
        """
        x = np.asarray(x).swapaxes(axis, 0)
        x = x[::-1, ...]
        x = x.swapaxes(0, axis)
        return x
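
    # e.g. flip_axis(x, 1) reverses the columns of an (H, W, C) slice, i.e. a horizontal flip.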

    def apply_transform(self, x, transform_parameters):
        """
        Applies the given affine transform and flips to a (H, W, C) array.
        """
        x = apply_affine_transform(x, transform_parameters.get('theta', 0),
                                   transform_parameters.get('tx', 0),
                                   transform_parameters.get('ty', 0),
                                   transform_parameters.get('shear', 0),
                                   transform_parameters.get('zx', 1),
                                   transform_parameters.get('zy', 1),
                                   row_axis=0,
                                   col_axis=1,
                                   channel_axis=2)
        if transform_parameters.get('flip_horizontal', False):
            x = self.flip_axis(x, 1)
        if transform_parameters.get('flip_vertical', False):
            x = self.flip_axis(x, 0)
        return x

    def get_random_transform(self):
        """
        Draws random augmentation parameters from the configured ranges.
        """
        if self.rotation_range:
            theta = np.random.uniform(-self.rotation_range, self.rotation_range)
        else:
            theta = 0

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)

        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        transform_parameters = {'flip_horizontal': flip_horizontal,
                                'flip_vertical': flip_vertical,
                                'theta': theta,
                                'zx': zx,
                                'zy': zy}

        return transform_parameters
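
    # A possible draw with rotation_range=10 and zoom_range=0.1 (illustrative values only):
    # {'flip_horizontal': True, 'flip_vertical': False, 'theta': -4.2, 'zx': 0.97, 'zy': 1.05}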


"""
The two following functions are taken from the ImageDataGenerator class of Keras.
https://github.com/keras-team/keras/blob/master/keras/preprocessing/image.py
"""

def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
                           row_axis=0, col_axis=1, channel_axis=2,
                           fill_mode='nearest', cval=0.):
    """Applies an affine transformation specified by the parameters given.

    # Arguments
        x: 2D numpy array, single image.
        theta: Rotation angle in degrees.
        tx: Width shift.
        ty: Height shift.
        shear: Shear angle in degrees.
        zx: Zoom in x direction.
        zy: Zoom in y direction.
        row_axis: Index of axis for rows in the input image.
        col_axis: Index of axis for columns in the input image.
        channel_axis: Index of axis for channels in the input image.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        The transformed version of the input.
    """
    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx],
                                 [0, 1, ty],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, h, w)
        x = np.rollaxis(x, channel_axis, 0)
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        # scipy.ndimage.affine_transform replaces the deprecated
        # scipy.ndimage.interpolation namespace
        channel_images = [scipy.ndimage.affine_transform(
            x_channel,
            final_affine_matrix,
            final_offset,
            order=1,
            mode=fill_mode,
            cval=cval) for x_channel in x]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x


def transform_matrix_offset_center(matrix, x, y):
    """
    Shifts the given transform so it is applied about the center of an
    image with x rows and y columns.
    """
    o_x = float(x) / 2 + 0.5
    o_y = float(y) / 2 + 0.5
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
    return transform_matrix
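

# Minimal usage sketch (illustrative, not part of the original module): the HDF5 path and
# fold indexes below are placeholders, and `model` is assumed to be a compiled Keras model
# defined elsewhere. PyTables is assumed because the generator reads `hdf5_file.root.data`.
if __name__ == '__main__':
    import tables

    hdf5_file = tables.open_file('data.hdf5', mode='r')   # hypothetical path
    fold_idx = np.arange(0, 8)                             # hypothetical training brains of one fold

    train_gen = CustomDataGenerator(hdf5_file, fold_idx, batch_size=16, view='axial',
                                    mode='train', horizontal_flip=True, rotation_range=10)
    val_gen = CustomDataGenerator(hdf5_file, fold_idx, mode='validation')

    # model.fit(train_gen, validation_data=val_gen, epochs=50,
    #           workers=4, use_multiprocessing=True)       # `model` compiled elsewhere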