a b/data_generator.py
1
"""
2
Utilities for a real-time, multi-threaded data generator.
3
"""
4
5
import scipy
6
import numpy as np
7
from tensorflow.keras.utils import Sequence, to_categorical
8
9
10
11
class CustomDataGenerator(Sequence):
    """Keras ``Sequence`` that yields augmented 2D slices of multi-modal brain volumes.

    Each batch is ``(X, Y)``: ``X`` stacks per-modality normalized 2D slices
    (modalities on the last axis) and ``Y`` stacks the matching one-hot label
    maps with ``NUM_CLASSES`` channels.
    """

    # Axis permutation applied to one brain volume (spatial axes 0-2, modality
    # axis 3) so that the axis sliced for the requested view comes first.
    _VIEW_AXES = {
        'axial': (0, 1, 2, 3),
        'sagittal': (2, 1, 0, 3),
        'coronal': (1, 2, 0, 3),
    }
    # Number of segmentation classes; label values must lie in 0..NUM_CLASSES-1.
    NUM_CLASSES = 4

    def __init__(self, hdf5_file, brain_idx, batch_size=16, view="axial", mode='train', horizontal_flip=False,
                 vertical_flip=False, rotation_range=0, zoom_range=0., shuffle=True):
        """
        Custom data generator based on the Keras ``Sequence`` class.

        This implementation supports multiprocessing and on-the-fly data
        augmentation, which speeds up training in tasks such as brain tumor
        segmentation where data processing is time-consuming.

        Parameters
        ----------
        hdf5_file : tables.File
            An open HDF5 file (PyTables-style API) exposing ``root.data``
            (image volumes) and ``root.truth`` (label volumes).
        brain_idx : array-like of int
            Brain indexes of a specific fold. In 'train' mode exactly these
            brains are used; in 'validation' mode every brain NOT listed in
            ``brain_idx`` is used.
        batch_size : int
            Number of input/output pairs produced per batch. The default is 16.
        view : str
            'axial', 'sagittal' or 'coronal'. The generator extracts 2D slices
            along this view and normalizes them slice by slice.
            The default is 'axial'.
        mode : str
            Prepare the generator for the 'train' or 'validation' phase.
            The default is 'train'.
        horizontal_flip : bool
            Whether to randomly flip slices horizontally. The default is False.
        vertical_flip : bool
            Whether to randomly flip slices vertically. The default is False.
        rotation_range : float
            Maximum random rotation in degrees. The default is 0.
        zoom_range : float
            Random zoom is drawn from [1 - zoom_range, 1 + zoom_range].
            The default is 0.
        shuffle : bool
            Whether to shuffle slices after each epoch. The default is True.
            Shuffling is never applied when mode='validation'.

        Raises
        ------
        ValueError
            If ``mode`` or ``view`` is not one of the supported values.

        NOTE(review): augmentation is applied regardless of ``mode``; only
        enable the flip/rotation/zoom options for training generators.
        """
        self.data_storage = hdf5_file.root.data
        self.truth_storage = hdf5_file.root.truth

        total_brains = self.data_storage.shape[0]
        self.brain_idx = self.get_brain_idx(brain_idx, mode, total_brains)
        self.batch_size = batch_size

        # Fail fast: an unknown view must not leave ``view_axes`` undefined.
        if view not in self._VIEW_AXES:
            raise ValueError('unknown input view => {}'.format(view))
        self.view_axes = self._VIEW_AXES[view]

        self.mode = mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rotation_range = rotation_range
        self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        self.shuffle = shuffle
        # Per-brain volume shape reordered for the chosen view:
        # (n_slices, rows, cols, n_modalities).
        volume_shape = self.data_storage.shape[1:]
        self.data_shape = tuple(int(volume_shape[axis]) for axis in self.view_axes)

        print('Using {} out of {} brains'.format(len(self.brain_idx), total_brains), end=' ')
        print('({} out of {} 2D slices)'.format(len(self.brain_idx) * self.data_shape[0], total_brains * self.data_shape[0]))
        print('the generated data shape in "{}" view: {}'.format(view, str(self.data_shape[1:])))
        print('-----'*10)

        self.on_epoch_end()

    @staticmethod
    def get_brain_idx(brain_idx, mode, total_brains):
        """
        Return the brain indexes the generator will draw slices from.

        In 'train' mode ``brain_idx`` is returned unchanged (the fold's
        training brains). In 'validation' mode the sorted complement of
        ``brain_idx`` within ``range(total_brains)`` is returned.

        Raises
        ------
        ValueError
            If ``mode`` is neither 'train' nor 'validation'.
        """
        if mode == 'validation':
            # Set difference instead of a per-element ``in`` scan over brain_idx.
            complement = np.setdiff1d(np.arange(total_brains), brain_idx)
            print('DataGenerator is preparing for validation mode ...')
            return complement
        if mode == 'train':
            print('DataGenerator is preparing for training mode ...')
            return brain_idx
        raise ValueError('unknown "{}" mode'.format(mode))

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.indexes) // self.batch_size

    def __getitem__(self, index):
        """Return the ``index``-th batch as ``(X_batch, Y_batch)``."""
        start = index * self.batch_size
        batch_idx = self.indexes[start:start + self.batch_size]
        return self.data_load_and_preprocess(batch_idx)

    def on_epoch_end(self):
        """
        Rebuild the list of (brain_number, slice_number) pairs for the next
        epoch, shuffling it in 'train' mode when ``shuffle`` is enabled.
        """
        n_slices = self.data_shape[0]
        self.indexes = [(brain, slice_number)
                        for brain in self.brain_idx
                        for slice_number in range(n_slices)]

        if self.mode == 'train' and self.shuffle:
            np.random.shuffle(self.indexes)

    def data_load_and_preprocess(self, idx):
        """
        Load, normalize and augment the slices listed in ``idx``.

        Parameters
        ----------
        idx : sequence of (brain_number, slice_number) pairs

        Returns
        -------
        (X, Y) : tuple of np.ndarray
            ``X`` stacks the normalized, augmented image slices (modalities on
            the last axis); ``Y`` stacks the matching one-hot label maps with
            ``NUM_CLASSES`` channels.
        """
        slice_batch = []
        label_batch = []

        for brain_number, slice_number in idx:
            slice_, label_ = self.read_data(brain_number, slice_number)
            slice_ = self.normalize_modalities(slice_)
            n_modalities = slice_.shape[-1]
            # One-hot encode the labels *before* augmenting. Rotation and zoom
            # interpolate bilinearly; interpolating raw class ids and then
            # casting them to int (as ``to_categorical`` does) invents classes
            # at region borders (e.g. halfway between 0 and 3 yields 1).
            # Interpolating one-hot channels and taking the argmax keeps every
            # label a real class.
            one_hot = to_categorical(label_[..., 0], self.NUM_CLASSES)
            stacked = np.concatenate((slice_, one_hot), axis=-1)
            stacked = self.apply_transform(stacked, self.get_random_transform())
            slice_ = stacked[..., :n_modalities]
            label_ = to_categorical(np.argmax(stacked[..., n_modalities:], axis=-1),
                                    self.NUM_CLASSES)

            slice_batch.append(slice_)
            label_batch.append(label_)

        return np.array(slice_batch), np.array(label_batch)

    @staticmethod
    def _read_view_slice(storage, brain_number, slice_number, axes):
        """
        Read one 2D plane of one brain, equal to
        ``storage[brain_number].transpose(axes)[slice_number]``, but indexing
        only that plane so the whole volume is not loaded from disk.
        """
        axis = axes[0]
        index = [brain_number] + [slice(None)] * len(axes)
        index[axis + 1] = slice_number
        plane = np.asarray(storage[tuple(index)])
        # Indexing keeps the remaining axes in ascending order; permute them
        # into the order requested by ``axes[1:]``.
        remaining = [a for a in range(len(axes)) if a != axis]
        return plane.transpose([remaining.index(a) for a in axes[1:]])

    def read_data(self, brain_number, slice_number):
        """
        Read image and label data for one slice with respect to the view.

        Returns
        -------
        slice_ : np.ndarray
            The image slice, modalities on the last axis.
        label_ : np.ndarray
            The label slice with a trailing singleton channel axis.
        """
        slice_ = self._read_view_slice(self.data_storage, brain_number,
                                       slice_number, self.view_axes)
        label_ = self._read_view_slice(self.truth_storage, brain_number,
                                       slice_number, self.view_axes[:3])
        label_ = np.expand_dims(label_, axis=-1)

        return slice_, label_

    def normalize_slice(self, slice):
        """
        Clip intensities to the [1st, 99th] percentile range, then z-score
        normalize the 2D slice. A constant slice is returned clipped but
        unscaled, to avoid dividing by a zero standard deviation.
        """
        low, high = np.percentile(slice, [1, 99])
        clipped = np.clip(slice, low, high)
        std = np.std(clipped)
        if std == 0:
            return clipped
        return (clipped - np.mean(clipped)) / std

    def normalize_modalities(self, Slice):
        """
        Normalize each modality (last-axis channel) of ``Slice`` independently
        and return the result as float32.
        """
        normalized_slices = np.zeros_like(Slice).astype(np.float32)
        for channel in range(Slice.shape[-1]):
            normalized_slices[..., channel] = self.normalize_slice(Slice[..., channel])

        return normalized_slices

    def flip_axis(self, x, axis):
        """Return ``x`` reversed along ``axis``."""
        return np.flip(np.asarray(x), axis)

    def apply_transform(self, x, transform_parameters):
        """
        Apply an affine transform and then optional flips to ``x``, an array
        laid out as (rows, cols, channels).

        ``transform_parameters`` is the dict produced by
        ``get_random_transform``; missing keys default to the identity.
        """
        x = apply_affine_transform(x, transform_parameters.get('theta', 0),
                                   transform_parameters.get('tx', 0),
                                   transform_parameters.get('ty', 0),
                                   transform_parameters.get('shear', 0),
                                   transform_parameters.get('zx', 1),
                                   transform_parameters.get('zy', 1),
                                   row_axis=0,
                                   col_axis=1,
                                   channel_axis=2)
        if transform_parameters.get('flip_horizontal', False):
            x = self.flip_axis(x, 1)
        if transform_parameters.get('flip_vertical', False):
            x = self.flip_axis(x, 0)
        return x

    def get_random_transform(self):
        """
        Draw random augmentation parameters for one sample.

        Returns
        -------
        dict
            Keys 'flip_horizontal', 'flip_vertical', 'theta' (degrees),
            'zx' and 'zy'.
        """
        if self.rotation_range:
            theta = np.random.uniform(-self.rotation_range, self.rotation_range)
        else:
            theta = 0

        low, high = self.zoom_range
        if low == 1 and high == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(low, high, 2)

        # Both flip draws happen even when flipping is disabled, so the random
        # numbers consumed per sample do not depend on the flip settings.
        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        return {'flip_horizontal': flip_horizontal,
                'flip_vertical': flip_vertical,
                'theta': theta,
                'zx': zx,
                'zy': zy}
254
        
255
"""
256
The following two functions are taken from the ImageDataGenerator code in Keras:
257
https://github.com/keras-team/keras/blob/master/keras/preprocessing/image.py
258
"""
259
    
260
def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
                           row_axis=0, col_axis=1, channel_axis=2,
                           fill_mode='nearest', cval=0., order=1):
    """Applies an affine transformation specified by the parameters given.

    # Arguments
        x: numpy array, a single image with a channel axis.
        theta: Rotation angle in degrees.
        tx: Width shift.
        ty: Height shift.
        shear: Shear angle in degrees.
        zx: Zoom in x direction.
        zy: Zoom in y direction.
        row_axis: Index of axis for rows in the input image.
        col_axis: Index of axis for columns in the input image.
        channel_axis: Index of axis for channels in the input image
            (negative indexes are supported).
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.
        order: Spline interpolation order passed to
            `scipy.ndimage.affine_transform` (1 = bilinear, the previous
            fixed behavior; 0 = nearest neighbour).

    # Returns
        The transformed version of the input. When every parameter describes
        the identity, `x` is returned unchanged.
    """
    # ``import scipy`` alone does not guarantee ``scipy.ndimage`` is loaded,
    # and the ``scipy.ndimage.interpolation`` namespace is deprecated.
    from scipy import ndimage

    # Collect the individual 3x3 homogeneous matrices in application order.
    matrices = []
    if theta != 0:
        theta = np.deg2rad(theta)
        matrices.append(np.array([[np.cos(theta), -np.sin(theta), 0],
                                  [np.sin(theta), np.cos(theta), 0],
                                  [0, 0, 1]]))

    if tx != 0 or ty != 0:
        matrices.append(np.array([[1, 0, tx],
                                  [0, 1, ty],
                                  [0, 0, 1]]))

    if shear != 0:
        shear = np.deg2rad(shear)
        matrices.append(np.array([[1, -np.sin(shear), 0],
                                  [0, np.cos(shear), 0],
                                  [0, 0, 1]]))

    if zx != 1 or zy != 1:
        matrices.append(np.array([[zx, 0, 0],
                                  [0, zy, 0],
                                  [0, 0, 1]]))

    if not matrices:
        return x

    transform_matrix = matrices[0]
    for matrix in matrices[1:]:
        transform_matrix = np.dot(transform_matrix, matrix)

    h, w = x.shape[row_axis], x.shape[col_axis]
    transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
    final_affine_matrix = transform_matrix[:2, :2]
    final_offset = transform_matrix[:2, 2]

    # Transform each channel as an independent 2D image.
    channels_first = np.moveaxis(x, channel_axis, 0)
    channel_images = [ndimage.affine_transform(
        x_channel,
        final_affine_matrix,
        final_offset,
        order=order,
        mode=fill_mode,
        cval=cval) for x_channel in channels_first]
    # np.moveaxis (unlike np.rollaxis(x, 0, channel_axis + 1)) also restores
    # the layout correctly for a negative channel_axis.
    return np.moveaxis(np.stack(channel_images, axis=0), 0, channel_axis)
339
340
        
341
342
def transform_matrix_offset_center(matrix, x, y):
    """Re-center an affine ``matrix`` so that it acts about the image center.

    The result is ``offset @ matrix @ reset``. The reset step shifts
    coordinates so the image center becomes the origin, then ``matrix`` is
    applied, and the offset step shifts back.

    # Arguments
        matrix: 3x3 homogeneous affine matrix.
        x: Image size along the first coordinate (the row count when called
            from ``apply_affine_transform``).
        y: Image size along the second coordinate (the column count when
            called from ``apply_affine_transform``).

    # Returns
        The re-centered 3x3 matrix.
    """
    # Pixel centers sit at 0 .. n-1, so the geometric center is (n - 1) / 2.
    # Using n / 2 + 0.5 would place it one pixel past the center and shift
    # every rotation/zoom by a pixel.
    o_x = float(x) / 2 - 0.5
    o_y = float(y) / 2 - 0.5
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    return np.dot(np.dot(offset_matrix, matrix), reset_matrix)
349
350
351
352