Diff of /image.py [000000] .. [38391a]

Switch to unified view

a b/image.py
1
"""Utilities for real-time data augmentation on image data.
2
"""
3
from __future__ import absolute_import
4
from __future__ import division
5
from __future__ import print_function
6
7
import numpy as np
8
import re
9
from scipy import linalg
10
import scipy.ndimage as ndi
11
from six.moves import range
12
import os
13
import threading
14
import warnings
15
import multiprocessing.pool
16
import cv2
17
from functools import partial
18
from skimage import data, img_as_float
19
from skimage import exposure
20
21
from . import get_keras_submodule
22
23
backend = get_keras_submodule('backend')
24
keras_utils = get_keras_submodule('utils')
25
26
# Pillow is an optional dependency: it is only needed by the PIL-based
# helpers (array_to_img, img_to_array, load_img, apply_brightness_shift).
try:
    from PIL import ImageEnhance
    from PIL import Image as pil_image
except ImportError:
    # Degrade gracefully; the functions that need PIL raise ImportError
    # at call time when `pil_image` is None.
    pil_image = None

if pil_image is not None:
    # Maps the user-facing interpolation names accepted by `load_img`
    # to PIL resampling filters.
    _PIL_INTERPOLATION_METHODS = {
        'nearest': pil_image.NEAREST,
        'bilinear': pil_image.BILINEAR,
        'bicubic': pil_image.BICUBIC,
        # NOTE(review): ANTIALIAS was removed in Pillow 10 (alias of
        # LANCZOS in later 9.x releases) — confirm the supported Pillow
        # versions for this entry.
        'antialias' : pil_image.ANTIALIAS,
    }
    # These methods were only introduced in version 3.4.0 (2016).
    if hasattr(pil_image, 'HAMMING'):
        _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING
    if hasattr(pil_image, 'BOX'):
        _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX
    # This method is new in version 1.1.3 (2013).
    if hasattr(pil_image, 'LANCZOS'):
        _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS
47
48
49
def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
                    fill_mode='nearest', cval=0.):
    """Performs a random rotation of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        rg: Rotation range, in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Rotated Numpy image tensor.
    """
    theta = np.random.uniform(-rg, rg)
    # Forward row/col axes explicitly: apply_affine_transform defaults to
    # (row_axis=0, col_axis=1), which disagrees with the (1, 2) defaults
    # here and would center the rotation on the wrong dimensions.
    x = apply_affine_transform(x, theta=theta,
                               row_axis=row_axis, col_axis=col_axis,
                               channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x
72
73
74
def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0,
                 fill_mode='nearest', cval=0.):
    """Performs a random spatial shift of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        wrg: Width shift range, as a float fraction of the width.
        hrg: Height shift range, as a float fraction of the height.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Shifted Numpy image tensor.
    """
    h, w = x.shape[row_axis], x.shape[col_axis]
    tx = np.random.uniform(-hrg, hrg) * h
    ty = np.random.uniform(-wrg, wrg) * w
    # Forward row/col axes explicitly: apply_affine_transform defaults to
    # (row_axis=0, col_axis=1), which disagrees with the (1, 2) defaults
    # here and would apply the shift on the wrong dimensions.
    x = apply_affine_transform(x, tx=tx, ty=ty,
                               row_axis=row_axis, col_axis=col_axis,
                               channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x
100
101
102
def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0,
                 fill_mode='nearest', cval=0.):
    """Performs a random spatial shear of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Sheared Numpy image tensor.
    """
    shear = np.random.uniform(-intensity, intensity)
    # Forward row/col axes explicitly: apply_affine_transform defaults to
    # (row_axis=0, col_axis=1), which disagrees with the (1, 2) defaults
    # here and would center the shear on the wrong dimensions.
    x = apply_affine_transform(x, shear=shear,
                               row_axis=row_axis, col_axis=col_axis,
                               channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x
125
126
127
def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0,
                fill_mode='nearest', cval=0.):
    """Performs a random spatial zoom of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        zoom_range: Tuple of floats; zoom range for width and height.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Zoomed Numpy image tensor.

    # Raises
        ValueError: if `zoom_range` isn't a tuple.
    """
    if len(zoom_range) != 2:
        raise ValueError('`zoom_range` should be a tuple or list of two'
                         ' floats. Received: ', zoom_range)

    if zoom_range[0] == 1 and zoom_range[1] == 1:
        # Identity zoom; skip drawing random factors.
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
    # Forward row/col axes explicitly: apply_affine_transform defaults to
    # (row_axis=0, col_axis=1), which disagrees with the (1, 2) defaults
    # here and would center the zoom on the wrong dimensions.
    x = apply_affine_transform(x, zx=zx, zy=zy,
                               row_axis=row_axis, col_axis=col_axis,
                               channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x
160
161
162
def apply_channel_shift(x, intensity, channel_axis=0):
    """Shifts every channel of `x` by `intensity`, clipped to the input range.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    # Returns
        Numpy image tensor.

    """
    # Work in channels-first layout so we can iterate per channel.
    channels_first = np.rollaxis(x, channel_axis, 0)
    lo = np.min(channels_first)
    hi = np.max(channels_first)
    # Clip against the *global* min/max of the input so the shift never
    # widens the original value range.
    shifted = np.stack(
        [np.clip(channel + intensity, lo, hi)
         for channel in channels_first],
        axis=0)
    return np.rollaxis(shifted, 0, channel_axis + 1)
184
185
186
def random_channel_shift(x, intensity_range, channel_axis=0):
    """Shifts all channels by a single value drawn uniformly at random.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity_range: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    # Returns
        Numpy image tensor.
    """
    # One shift amount is drawn and applied identically to every channel.
    shift = np.random.uniform(-intensity_range, intensity_range)
    return apply_channel_shift(x, shift, channel_axis=channel_axis)
199
200
201
def apply_brightness_shift(x, brightness):
    """Performs a brightness shift.

    # Arguments
        x: Input tensor. Must be 3D.
        brightness: Float. The new brightness value
            (1.0 is a no-op, < 1 darkens, > 1 brightens).

    # Returns
        Numpy image tensor.

    # Raises
        ImportError: if PIL is not available.
    """
    if pil_image is None:
        raise ImportError('Could not import PIL.Image. '
                          'The use of `apply_brightness_shift` requires PIL.')
    img = array_to_img(x)
    # The original code did `x = imgenhancer_Brightness = ImageEnhance...`,
    # a redundant double assignment; a single enhancer call is equivalent.
    img = ImageEnhance.Brightness(img).enhance(brightness)
    return img_to_array(img)
220
221
222
def random_brightness(x, brightness_range):
223
    """Performs a random brightness shift.
224
225
    # Arguments
226
        x: Input tensor. Must be 3D.
227
        brightness_range: Tuple of floats; brightness range.
228
        channel_axis: Index of axis for channels in the input tensor.
229
230
    # Returns
231
        Numpy image tensor.
232
233
    # Raises
234
        ValueError if `brightness_range` isn't a tuple.
235
    """
236
    if len(brightness_range) != 2:
237
        raise ValueError(
238
            '`brightness_range should be tuple or list of two floats. '
239
            'Received: %s' % brightness_range)
240
241
    u = np.random.uniform(brightness_range[0], brightness_range[1])
242
    return apply_brightness_shift(x, u)
243
244
245
def transform_matrix_offset_center(matrix, x, y):
    """Re-centers a 3x3 homogeneous transform about the image center.

    Conjugates `matrix` with translations to/from the center of an
    image of size (x, y), so the transform rotates/zooms about the
    center instead of the top-left corner.
    """
    center_x = float(x) / 2 + 0.5
    center_y = float(y) / 2 + 0.5
    to_center = np.array([[1, 0, center_x],
                          [0, 1, center_y],
                          [0, 0, 1]])
    from_center = np.array([[1, 0, -center_x],
                            [0, 1, -center_y],
                            [0, 0, 1]])
    return to_center.dot(matrix).dot(from_center)
252
253
254
def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
                           row_axis=0, col_axis=1, channel_axis=2,
                           fill_mode='nearest', cval=0.):
    """Applies an affine transformation specified by the parameters given.

    # Arguments
        x: 3D numpy array, single image.
        theta: Rotation angle in degrees.
        tx: Width shift.
        ty: Height shift.
        shear: Shear angle in degrees.
        zx: Zoom in x direction.
        zy: Zoom in y direction.
        row_axis: Index of axis for rows in the input image.
        col_axis: Index of axis for columns in the input image.
        channel_axis: Index of axis for channels in the input image.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        The transformed version of the input.
    """
    # Compose the requested elementary transforms (homogeneous 3x3
    # matrices) into one; `None` means identity, in which case the
    # input is returned untouched.
    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx],
                                 [0, 1, ty],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, h, w)
        x = np.rollaxis(x, channel_axis, 0)
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        # Use the public `ndi.affine_transform`; the old
        # `ndi.interpolation.affine_transform` alias is deprecated and
        # removed from recent SciPy releases.
        channel_images = [ndi.affine_transform(
            x_channel,
            final_affine_matrix,
            final_offset,
            order=1,  # bi-linear interpolation
            mode=fill_mode,
            cval=cval) for x_channel in x]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x
333
334
def rgb2gray(rgb):
    """Converts an RGB image of shape (H, W, 3) to a (H, W) grayscale
    array using ITU-R BT.601 luma weights.
    """
    return (0.2989 * rgb[:, :, 0]
            + 0.5870 * rgb[:, :, 1]
            + 0.1140 * rgb[:, :, 2])
338
339
340
def flip_axis(x, axis):
    """Reverses the entries of `x` along the given axis."""
    arr = np.asarray(x)
    # Build a full slice for every dimension, reversing only `axis`.
    slicer = [slice(None)] * arr.ndim
    slicer[axis] = slice(None, None, -1)
    return arr[tuple(slicer)]
345
346
347
def array_to_img(x, data_format=None, scale=True):
    """Converts a 3D Numpy array to a PIL Image instance.

    # Arguments
        x: Input Numpy array.
        data_format: Image data format,
            either "channels_first" or "channels_last".
        scale: Whether to rescale image values
            to be within `[0, 255]`.

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
        ValueError: if invalid `x` or `data_format` is passed.
    """
    if pil_image is None:
        raise ImportError('Could not import PIL.Image. '
                          'The use of `array_to_img` requires PIL.')
    arr = np.asarray(x, dtype=backend.floatx())
    if arr.ndim != 3:
        raise ValueError('Expected image array to have rank 3 (single image). '
                         'Got array with shape:', arr.shape)

    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Invalid data_format:', data_format)

    # Numpy layout is (height, width, channel) or (channel, height, width);
    # PIL only accepts a channels-last layout.
    if data_format == 'channels_first':
        arr = arr.transpose(1, 2, 0)
    if scale:
        # Shift so the minimum is 0, then normalize the peak to 255.
        arr = arr + max(-np.min(arr), 0)
        peak = np.max(arr)
        if peak != 0:
            arr /= peak
        arr *= 255
    n_channels = arr.shape[2]
    if n_channels == 3:
        # RGB
        return pil_image.fromarray(arr.astype('uint8'), 'RGB')
    if n_channels == 1:
        # grayscale
        return pil_image.fromarray(arr[:, :, 0].astype('uint8'), 'L')
    raise ValueError('Unsupported channel number: ', arr.shape[2])
396
397
398
def img_to_array(img, data_format=None):
    """Converts a PIL Image instance to a Numpy array.

    # Arguments
        img: PIL Image instance.
        data_format: Image data format,
            either "channels_first" or "channels_last".

    # Returns
        A 3D Numpy array.

    # Raises
        ValueError: if invalid `img` or `data_format` is passed.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format: ', data_format)
    # A PIL image converts to (height, width, channel) for color images
    # or (height, width) for grayscale ones.
    arr = np.asarray(img, dtype=backend.floatx())
    if arr.ndim == 3:
        if data_format == 'channels_first':
            arr = arr.transpose(2, 0, 1)
    elif arr.ndim == 2:
        # Promote grayscale images to rank 3 with a singleton channel.
        if data_format == 'channels_first':
            arr = arr.reshape((1,) + arr.shape)
        else:
            arr = arr.reshape(arr.shape + (1,))
    else:
        raise ValueError('Unsupported image shape: ', arr.shape)
    return arr
431
432
433
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True, **kwargs):
    """Saves an image stored as a Numpy array to a path or file object.

    # Arguments
        path: Path or file object.
        x: Numpy array.
        data_format: Image data format,
            either "channels_first" or "channels_last".
        file_format: Optional file format override. If omitted, the
            format to use is determined from the filename extension.
            If a file object was used instead of a filename, this
            parameter should always be used.
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
    """
    # Convert first; the actual encoding is delegated to PIL.
    pil_img = array_to_img(x, data_format=data_format, scale=scale)
    pil_img.save(path, format=file_format, **kwargs)
454
455
456
def load_img(path, grayscale=False, target_size=None,
             interpolation='nearest'):
    """Loads an image into PIL format.

    # Arguments
        path: Path to image file.
        grayscale: Boolean, whether to load the image as grayscale.
        target_size: Either `None` (default to original size)
            or tuple of ints `(img_height, img_width)`.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
        ValueError: if interpolation method is not supported.
    """
    if pil_image is None:
        # Fixed copy-paste error: the message previously named
        # `array_to_img` instead of this function.
        raise ImportError('Could not import PIL.Image. '
                          'The use of `load_img` requires PIL.')
    img = pil_image.open(path)
    if grayscale:
        if img.mode != 'L':
            img = img.convert('L')
    else:
        if img.mode != 'RGB':
            img = img.convert('RGB')
    if target_size is not None:
        # PIL expects (width, height) while target_size is (height, width).
        width_height_tuple = (target_size[1], target_size[0])
        if img.size != width_height_tuple:
            if interpolation not in _PIL_INTERPOLATION_METHODS:
                raise ValueError(
                    'Invalid interpolation method {} specified. Supported '
                    'methods are {}'.format(
                        interpolation,
                        ", ".join(_PIL_INTERPOLATION_METHODS.keys())))
            resample = _PIL_INTERPOLATION_METHODS[interpolation]
            img = img.resize(width_height_tuple, resample)
    return img
501
502
503
def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
    """Recursively lists picture files under `directory`.

    # Arguments
        directory: Root directory to walk.
        ext: Regex alternation of accepted (lowercase) file extensions.

    # Returns
        List of paths of files whose lowercase name is exactly
        `<word-chars>.<ext>`.
    """
    # Anchor the pattern with `$`: the original prefix-only `re.match`
    # also accepted names such as `photo.jpg.bak` or `a.jpgx`.
    matcher = re.compile(r'([\w]+\.(?:' + ext + r'))$')
    return [os.path.join(root, f)
            for root, _, files in os.walk(directory) for f in files
            if matcher.match(f.lower())]
507
508
509
class ImageDataGenerator(object):
510
    """Generate batches of tensor image data with real-time data augmentation.
511
     The data will be looped over (in batches).
512
513
    # Arguments
514
        featurewise_center: Boolean.
515
            Set input mean to 0 over the dataset, feature-wise.
516
        samplewise_center: Boolean. Set each sample mean to 0.
517
        featurewise_std_normalization: Boolean.
518
            Divide inputs by std of the dataset, feature-wise.
519
        samplewise_std_normalization: Boolean. Divide each input by its std.
520
        zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
521
        zca_whitening: Boolean. Apply ZCA whitening.
522
        rotation_range: Int. Degree range for random rotations.
523
        width_shift_range: Float, 1-D array-like or int
524
            - float: fraction of total width, if < 1, or pixels if >= 1.
525
            - 1-D array-like: random elements from the array.
526
            - int: integer number of pixels from interval
527
                `(-width_shift_range, +width_shift_range)`
528
            - With `width_shift_range=2` possible values
529
                are integers `[-1, 0, +1]`,
530
                same as with `width_shift_range=[-1, 0, +1]`,
531
                while with `width_shift_range=1.0` possible values are floats
532
                in the interval [-1.0, +1.0).
533
        height_shift_range: Float, 1-D array-like or int
534
            - float: fraction of total height, if < 1, or pixels if >= 1.
535
            - 1-D array-like: random elements from the array.
536
            - int: integer number of pixels from interval
537
                `(-height_shift_range, +height_shift_range)`
538
            - With `height_shift_range=2` possible values
539
                are integers `[-1, 0, +1]`,
540
                same as with `height_shift_range=[-1, 0, +1]`,
541
                while with `height_shift_range=1.0` possible values are floats
542
                in the interval [-1.0, +1.0).
543
        shear_range: Float. Shear Intensity
544
            (Shear angle in counter-clockwise direction in degrees)
545
        zoom_range: Float or [lower, upper]. Range for random zoom.
546
            If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
547
        channel_shift_range: Float. Range for random channel shifts.
548
        fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
549
            Default is 'nearest'.
550
            Points outside the boundaries of the input are filled
551
            according to the given mode:
552
            - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
553
            - 'nearest':  aaaaaaaa|abcd|dddddddd
554
            - 'reflect':  abcddcba|abcd|dcbaabcd
555
            - 'wrap':  abcdabcd|abcd|abcdabcd
556
        cval: Float or Int.
557
            Value used for points outside the boundaries
558
            when `fill_mode = "constant"`.
559
        horizontal_flip: Boolean. Randomly flip inputs horizontally.
560
        vertical_flip: Boolean. Randomly flip inputs vertically.
561
        rescale: rescaling factor. Defaults to None.
562
            If None or 0, no rescaling is applied,
563
            otherwise we multiply the data by the value provided
564
            (before applying any other transformation).
565
        preprocessing_function: function that will be implied on each input.
566
            The function will run after the image is resized and augmented.
567
            The function should take one argument:
568
            one image (Numpy tensor with rank 3),
569
            and should output a Numpy tensor with the same shape.
570
        data_format: Image data format,
571
            either "channels_first" or "channels_last".
572
            "channels_last" mode means that the images should have shape
573
            `(samples, height, width, channels)`,
574
            "channels_first" mode means that the images should have shape
575
            `(samples, channels, height, width)`.
576
            It defaults to the `image_data_format` value found in your
577
            Keras config file at `~/.keras/keras.json`.
578
            If you never set it, then it will be "channels_last".
579
        validation_split: Float. Fraction of images reserved for validation
580
            (strictly between 0 and 1).
581
582
    # Examples
583
    Example of using `.flow(x, y)`:
584
585
    ```python
586
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
587
    y_train = np_utils.to_categorical(y_train, num_classes)
588
    y_test = np_utils.to_categorical(y_test, num_classes)
589
590
    datagen = ImageDataGenerator(
591
        featurewise_center=True,
592
        featurewise_std_normalization=True,
593
        rotation_range=20,
594
        width_shift_range=0.2,
595
        height_shift_range=0.2,
596
        horizontal_flip=True)
597
598
    # compute quantities required for featurewise normalization
599
    # (std, mean, and principal components if ZCA whitening is applied)
600
    datagen.fit(x_train)
601
602
    # fits the model on batches with real-time data augmentation:
603
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
604
                        steps_per_epoch=len(x_train) / 32, epochs=epochs)
605
606
    # here's a more "manual" example
607
    for e in range(epochs):
608
        print('Epoch', e)
609
        batches = 0
610
        for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
611
            model.fit(x_batch, y_batch)
612
            batches += 1
613
            if batches >= len(x_train) / 32:
614
                # we need to break the loop by hand because
615
                # the generator loops indefinitely
616
                break
617
    ```
618
    Example of using `.flow_from_directory(directory)`:
619
620
    ```python
621
    train_datagen = ImageDataGenerator(
622
            rescale=1./255,
623
            shear_range=0.2,
624
            zoom_range=0.2,
625
            horizontal_flip=True)
626
627
    test_datagen = ImageDataGenerator(rescale=1./255)
628
629
    train_generator = train_datagen.flow_from_directory(
630
            'data/train',
631
            target_size=(150, 150),
632
            batch_size=32,
633
            class_mode='binary')
634
635
    validation_generator = test_datagen.flow_from_directory(
636
            'data/validation',
637
            target_size=(150, 150),
638
            batch_size=32,
639
            class_mode='binary')
640
641
    model.fit_generator(
642
            train_generator,
643
            steps_per_epoch=2000,
644
            epochs=50,
645
            validation_data=validation_generator,
646
            validation_steps=800)
647
    ```
648
649
    Example of transforming images and masks together.
650
651
    ```python
652
    # we create two instances with the same arguments
653
    data_gen_args = dict(featurewise_center=True,
654
                         featurewise_std_normalization=True,
655
                         rotation_range=90.,
656
                         width_shift_range=0.1,
657
                         height_shift_range=0.1,
658
                         zoom_range=0.2)
659
    image_datagen = ImageDataGenerator(**data_gen_args)
660
    mask_datagen = ImageDataGenerator(**data_gen_args)
661
662
    # Provide the same seed and keyword arguments to the fit and flow methods
663
    seed = 1
664
    image_datagen.fit(images, augment=True, seed=seed)
665
    mask_datagen.fit(masks, augment=True, seed=seed)
666
667
    image_generator = image_datagen.flow_from_directory(
668
        'data/images',
669
        class_mode=None,
670
        seed=seed)
671
672
    mask_generator = mask_datagen.flow_from_directory(
673
        'data/masks',
674
        class_mode=None,
675
        seed=seed)
676
677
    # combine generators into one which yields image and masks
678
    train_generator = zip(image_generator, mask_generator)
679
680
    model.fit_generator(
681
        train_generator,
682
        steps_per_epoch=2000,
683
        epochs=50)
684
    ```
685
    """
686
687
    def __init__(self,
                 contrast_stretching=False,
                 histogram_equalization=False,
                 adaptive_equalization=False,
                 featurewise_center=False,
                 samplewise_center=False,
                 featurewise_std_normalization=False,
                 samplewise_std_normalization=False,
                 zca_whitening=False,
                 zca_epsilon=1e-6,
                 rotation_range=0.,
                 width_shift_range=0.,
                 height_shift_range=0.,
                 brightness_range=None,
                 shear_range=0.,
                 zoom_range=0.,
                 channel_shift_range=0.,
                 fill_mode='nearest',
                 cval=0.,
                 horizontal_flip=False,
                 vertical_flip=False,
                 rescale=None,
                 preprocessing_function=None,
                 data_format=None,
                 validation_split=0.0):
        """Stores the augmentation configuration and validates it.

        See the class docstring for the meaning of the standard Keras
        arguments. The first three flags (`contrast_stretching`,
        `histogram_equalization`, `adaptive_equalization`) are extensions
        over stock Keras; they are presumably consumed by the transform
        pipeline elsewhere in this file — not visible in this chunk,
        TODO confirm.
        """
        # Fall back to the globally configured format (~/.keras/keras.json).
        if data_format is None:
            data_format = backend.image_data_format()
        self.contrast_stretching = contrast_stretching
        self.histogram_equalization = histogram_equalization
        self.adaptive_equalization = adaptive_equalization
        self.featurewise_center = featurewise_center
        self.samplewise_center = samplewise_center
        self.featurewise_std_normalization = featurewise_std_normalization
        self.samplewise_std_normalization = samplewise_std_normalization
        self.zca_whitening = zca_whitening
        self.zca_epsilon = zca_epsilon
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.brightness_range = brightness_range
        self.shear_range = shear_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rescale = rescale
        self.preprocessing_function = preprocessing_function

        if data_format not in {'channels_last', 'channels_first'}:
            raise ValueError(
                '`data_format` should be `"channels_last"` '
                '(channel after row and column) or '
                '`"channels_first"` (channel before row and column). '
                'Received: %s' % data_format)
        self.data_format = data_format
        # Axis indices are 1-based here because index 0 is the batch axis
        # of the rank-4 tensors this generator produces.
        if data_format == 'channels_first':
            self.channel_axis = 1
            self.row_axis = 2
            self.col_axis = 3
        if data_format == 'channels_last':
            self.channel_axis = 3
            self.row_axis = 1
            self.col_axis = 2
        # 0.0 disables the validation split entirely (falsy); any other
        # value must lie strictly inside (0, 1).
        if validation_split and not 0 < validation_split < 1:
            raise ValueError(
                '`validation_split` must be strictly between 0 and 1. '
                ' Received: %s' % validation_split)
        self._validation_split = validation_split

        # Featurewise statistics; populated later by `fit`.
        self.mean = None
        self.std = None
        self.principal_components = None

        # Normalize scalar zoom_range to the [lower, upper] form.
        if np.isscalar(zoom_range):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif len(zoom_range) == 2:
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError('`zoom_range` should be a float or '
                             'a tuple or list of two floats. '
                             'Received: %s' % zoom_range)
        # The following blocks reconcile mutually dependent normalization
        # flags, warning when an implied setting overrides the user's.
        if zca_whitening:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening`, which overrides '
                              'setting of `featurewise_center`.')
            if featurewise_std_normalization:
                self.featurewise_std_normalization = False
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening` '
                              'which overrides setting of'
                              '`featurewise_std_normalization`.')
        if featurewise_std_normalization:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_std_normalization`, '
                              'which overrides setting of '
                              '`featurewise_center`.')
        if samplewise_std_normalization:
            if not samplewise_center:
                self.samplewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`samplewise_std_normalization`, '
                              'which overrides setting of '
                              '`samplewise_center`.')
796
797
    def flow(self, x,
             y=None, batch_size=32, shuffle=True,
             sample_weight=None, seed=None,
             save_to_dir=None, save_prefix='', save_format='png', subset=None):
        """Takes data & label arrays, generates batches of augmented data.

        # Arguments
            x: Input data. Numpy array of rank 4 or a tuple.
                If tuple, the first element holds the images and the second
                element is another numpy array or a list of numpy arrays that
                is passed through to the output unmodified (useful for feeding
                miscellaneous data alongside the images).
                Grayscale data should have a channels axis of value 1,
                RGB data a value of 3.
            y: Labels.
            batch_size: Int (default: 32).
            shuffle: Boolean (default: True).
            sample_weight: Sample weights.
            seed: Int (default: None).
            save_to_dir: None or str (default: None). Optional directory
                where the augmented pictures being generated are saved
                (useful for visualizing what you are doing).
            save_prefix: Str (default: `''`). Filename prefix for saved
                pictures (only relevant if `save_to_dir` is set).
            save_format: one of "png", "jpeg"
                (only relevant if `save_to_dir` is set). Default: "png".
            subset: Subset of data (`"training"` or `"validation"`) if
                `validation_split` is set in `ImageDataGenerator`.

        # Returns
            An `Iterator` yielding tuples of `(x, y)` where `x` is a numpy
            array of image data (or a list of numpy arrays when additional
            inputs are given) and `y` is a numpy array of corresponding
            labels. If `sample_weight` is not None, the yielded tuples are
            of the form `(x, y, sample_weight)`. If `y` is None, only the
            numpy array `x` is returned.
        """
        # Collect the pass-through options once, then delegate to the
        # iterator that does the actual batching and augmentation.
        iterator_kwargs = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            sample_weight=sample_weight,
            seed=seed,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            subset=subset,
        )
        return NumpyArrayIterator(x, y, self, **iterator_kwargs)
854
    def flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None,
                            save_prefix='',
                            save_format='png',
                            follow_links=False,
                            subset=None,
                            interpolation='nearest'):
        """Takes the path to a directory & generates batches of augmented data.

        # Arguments
            directory: Path to the target directory.
                It should contain one subdirectory per class.
                Any PNG, JPG, BMP, PPM or TIF images
                inside each of the subdirectories directory tree
                will be included in the generator.
                See [this script](
                    https://gist.github.com/fchollet/
                    0830affa1f7f19fd47b06d4cf89ed44d)
                for more details.
            target_size: Tuple of integers `(height, width)`,
                default: `(256, 256)`.
                The dimensions to which all images found will be resized.
            color_mode: One of "grayscale", "rgb". Default: "rgb".
                Whether the images will be converted to
                have 1 or 3 color channels.
            classes: Optional list of class subdirectories
                (e.g. `['dogs', 'cats']`). Default: None.
                If not provided, the list of classes will be automatically
                inferred from the subdirectory names/structure
                under `directory`, where each subdirectory will
                be treated as a different class
                (and the order of the classes, which will map to the label
                indices, will be alphanumeric).
                The dictionary containing the mapping from class names to class
                indices can be obtained via the attribute `class_indices`.
            class_mode: One of "categorical", "binary", "sparse",
                "input", or None. Default: "categorical".
                Determines the type of label arrays that are returned:
                - "categorical" will be 2D one-hot encoded labels,
                - "binary" will be 1D binary labels,
                - "sparse" will be 1D integer labels,
                - "input" will be images identical
                    to input images (mainly used to work with autoencoders).
                - If None, no labels are returned
                  (the generator will only yield batches of image data,
                  which is useful to use with `model.predict_generator()`,
                  `model.evaluate_generator()`, etc.).
                  Please note that in case of class_mode None,
                  the data still needs to reside in a subdirectory
                  of `directory` for it to work correctly.
            batch_size: Size of the batches of data (default: 32).
            shuffle: Whether to shuffle the data (default: True)
            seed: Optional random seed for shuffling and transformations.
            save_to_dir: None or str (default: None).
                This allows you to optionally specify
                a directory to which to save
                the augmented pictures being generated
                (useful for visualizing what you are doing).
            save_prefix: Str. Prefix to use for filenames of saved pictures
                (only relevant if `save_to_dir` is set).
            save_format: One of "png", "jpeg"
                (only relevant if `save_to_dir` is set). Default: "png".
            follow_links: Whether to follow symlinks inside
                class subdirectories (default: False).
            subset: Subset of data (`"training"` or `"validation"`) if
                `validation_split` is set in `ImageDataGenerator`.
            interpolation: Interpolation method used to
                resample the image if the
                target size is different from that of the loaded image.
                Supported methods are `"nearest"`, `"bilinear"`,
                and `"bicubic"`.
                If PIL version 1.1.3 or newer is installed, `"lanczos"` is also
                supported. If PIL version 3.4.0 or newer is installed,
                `"box"` and `"hamming"` are also supported.
                By default, `"nearest"` is used.

        # Returns
            A `DirectoryIterator` yielding tuples of `(x, y)`
                where `x` is a numpy array containing a batch
                of images with shape `(batch_size, *target_size, channels)`
                and `y` is a numpy array of corresponding labels.
        """
        # Pure delegation: all augmentation settings live on `self`, the
        # iterator handles directory scanning, loading and batching.
        return DirectoryIterator(
            directory, self,
            target_size=target_size, color_mode=color_mode,
            classes=classes, class_mode=class_mode,
            data_format=self.data_format,
            batch_size=batch_size, shuffle=shuffle, seed=seed,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            follow_links=follow_links,
            subset=subset,
            interpolation=interpolation)
951
952
    def standardize(self, x):
        """Applies the normalization configuration to a batch of inputs.

        Note that the order of operations is: rescale, then the user's
        `preprocessing_function`, then sample-wise normalization, then
        feature-wise normalization / ZCA whitening.

        # Arguments
            x: Batch of inputs to be normalized.

        # Returns
            The inputs, normalized.
        """
        if self.rescale:
            x *= self.rescale
        if self.preprocessing_function:
            x = self.preprocessing_function(x)
        if self.samplewise_center:
            x -= np.mean(x, keepdims=True)
        if self.samplewise_std_normalization:
            # epsilon guards against division by zero for constant inputs.
            x /= (np.std(x, keepdims=True) + backend.epsilon())

        # The feature-wise statistics below are only available after `.fit()`
        # has been called on sample data; otherwise warn and skip.
        if self.featurewise_center:
            if self.mean is not None:
                x -= self.mean
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_center`, but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')
        if self.featurewise_std_normalization:
            if self.std is not None:
                x /= (self.std + backend.epsilon())
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_std_normalization`, '
                              'but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')
        if self.zca_whitening:
            if self.principal_components is not None:
                # Flatten each sample, project onto the whitening matrix,
                # then restore the original shape.
                flatx = np.reshape(x, (-1, np.prod(x.shape[-3:])))
                whitex = np.dot(flatx, self.principal_components)
                x = np.reshape(whitex, x.shape)
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening`, but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')
        return x
1021
1022
1023
    def get_random_transform(self, img_shape, seed=None):
        """Generates random parameters for a transformation.

        # Arguments
            img_shape: Tuple of integers.
                Shape of the image that is transformed.
            seed: Random seed.

        # Returns
            A dictionary containing randomly chosen parameters describing the
            transformation.

        # Raises
            ValueError: If `self.brightness_range` is set but is not a
                tuple/list of two floats.
        """
        # `self.row_axis`/`self.col_axis` are batch axes; subtract 1 to index
        # into a single image's shape.
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1

        if seed is not None:
            np.random.seed(seed)

        if self.rotation_range:
            theta = np.random.uniform(
                -self.rotation_range,
                self.rotation_range)
        else:
            theta = 0

        if self.height_shift_range:
            try:  # 1-D array-like or int
                tx = np.random.choice(self.height_shift_range)
                tx *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                tx = np.random.uniform(-self.height_shift_range,
                                       self.height_shift_range)
            # Values < 1 are treated as a fraction of the image height.
            if np.max(self.height_shift_range) < 1:
                tx *= img_shape[img_row_axis]
        else:
            tx = 0

        if self.width_shift_range:
            try:  # 1-D array-like or int
                ty = np.random.choice(self.width_shift_range)
                ty *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                ty = np.random.uniform(-self.width_shift_range,
                                       self.width_shift_range)
            # Values < 1 are treated as a fraction of the image width.
            if np.max(self.width_shift_range) < 1:
                ty *= img_shape[img_col_axis]
        else:
            ty = 0

        if self.shear_range:
            shear = np.random.uniform(
                -self.shear_range,
                self.shear_range)
        else:
            shear = 0

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(
                self.zoom_range[0],
                self.zoom_range[1],
                2)

        # Flip with probability 0.5, but only if the corresponding flag is on.
        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        channel_shift_intensity = None
        if self.channel_shift_range != 0:
            channel_shift_intensity = np.random.uniform(
                -self.channel_shift_range,
                self.channel_shift_range)

        brightness = None
        if self.brightness_range is not None:
            if len(self.brightness_range) != 2:
                # Fixed: the message previously referenced the bare name
                # `brightness_range`, which raised a NameError instead of the
                # intended ValueError.
                raise ValueError(
                    '`brightness_range` should be tuple or list of two '
                    'floats. Received: %s' % (self.brightness_range,))
            brightness = np.random.uniform(self.brightness_range[0],
                                           self.brightness_range[1])

        transform_parameters = {'theta': theta,
                                'tx': tx,
                                'ty': ty,
                                'shear': shear,
                                'zx': zx,
                                'zy': zy,
                                'flip_horizontal': flip_horizontal,
                                'flip_vertical': flip_vertical,
                                'channel_shift_intensity': channel_shift_intensity,
                                'brightness': brightness,
                                'contrast_stretching': self.contrast_stretching,
                                'adaptive_equalization': self.adaptive_equalization,
                                'histogram_equalization': self.histogram_equalization
                                }

        return transform_parameters
1120
1121
    def apply_transform(self, x, transform_parameters):
        """Applies a transformation to an image according to given parameters.

        # Arguments
            x: 3D tensor, single image.
            transform_parameters: Dictionary with string - parameter pairs
                describing the transformation.
                Currently, the following parameters
                from the dictionary are used:
                - `'theta'`: Float. Rotation angle in degrees.
                - `'tx'`: Float. Shift in the x direction.
                - `'ty'`: Float. Shift in the y direction.
                - `'shear'`: Float. Shear angle in degrees.
                - `'zx'`: Float. Zoom in the x direction.
                - `'zy'`: Float. Zoom in the y direction.
                - `'flip_horizontal'`: Boolean. Horizontal flip.
                - `'flip_vertical'`: Boolean. Vertical flip.
                - `'channel_shift_intensity'`: Float. Channel shift intensity.
                - `'brightness'`: Float. Brightness shift intensity.
                - `'contrast_stretching'`: Boolean. Percentile-based
                    contrast stretching.
                - `'histogram_equalization'`: Boolean. Per-channel
                    histogram equalization (assumes 3-channel input).

        # Returns
            A transformed version of the input (same shape).
        """
        # x is a single image, so it doesn't have image number at index 0
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1
        img_channel_axis = self.channel_axis - 1

        x = apply_affine_transform(x, transform_parameters.get('theta', 0),
                                   transform_parameters.get('tx', 0),
                                   transform_parameters.get('ty', 0),
                                   transform_parameters.get('shear', 0),
                                   transform_parameters.get('zx', 1),
                                   transform_parameters.get('zy', 1),
                                   row_axis=img_row_axis, col_axis=img_col_axis,
                                   channel_axis=img_channel_axis,
                                   fill_mode=self.fill_mode, cval=self.cval)

        if transform_parameters.get('channel_shift_intensity') is not None:
            x = apply_channel_shift(x,
                                    transform_parameters['channel_shift_intensity'],
                                    img_channel_axis)

        if transform_parameters.get('flip_horizontal', False):
            x = flip_axis(x, img_col_axis)

        if transform_parameters.get('flip_vertical', False):
            x = flip_axis(x, img_row_axis)

        if transform_parameters.get('brightness') is not None:
            x = apply_brightness_shift(x, transform_parameters['brightness'])

        # Fixed: these flags are booleans, so they must be tested for
        # truthiness — the previous `is not None` check fired even when the
        # flag was explicitly False, applying the enhancement unconditionally.
        # (The redundant `np.random.random() < 1.0` guard, always true, was
        # removed as well.)
        if transform_parameters.get('contrast_stretching'):
            x = img_to_array(x)
            # Stretch intensities so the 2nd..98th percentile spans the
            # full range.
            p2, p98 = np.percentile(x, (2, 98))
            x = exposure.rescale_intensity(x, in_range=(p2, p98))

        # NOTE(review): 'adaptive_equalization' is accepted in the parameter
        # dict but currently has no effect (its implementation is disabled).

        if transform_parameters.get('histogram_equalization'):
            # Equalize each channel independently; assumes channels-last with
            # exactly 3 channels — TODO confirm for channels_first data.
            x[:, :, 0] = exposure.equalize_hist(x[:, :, 0])
            x[:, :, 1] = exposure.equalize_hist(x[:, :, 1])
            x[:, :, 2] = exposure.equalize_hist(x[:, :, 2])

        return x
1198
1199
    def random_transform(self, x, seed=None):
        """Applies a random transformation to an image.

        Draws a fresh set of random transformation parameters and then
        applies them to `x` in one step.

        # Arguments
            x: 3D tensor, single image.
            seed: Random seed.

        # Returns
            A randomly transformed version of the input (same shape).
        """
        return self.apply_transform(
            x, self.get_random_transform(x.shape, seed))
1211
1212
    def fit(self, x,
            augment=False,
            rounds=1,
            seed=None):
        """Fits the data generator to some sample data.

        This computes the internal data stats related to the
        data-dependent transformations, based on an array of sample data.

        Only required if `featurewise_center` or
        `featurewise_std_normalization` or `zca_whitening` are set to True.

        # Arguments
            x: Sample data. Should have rank 4.
             In case of grayscale data,
             the channels axis should have value 1, and in case
             of RGB data, it should have value 3.
            augment: Boolean (default: False).
                Whether to fit on randomly augmented samples.
            rounds: Int (default: 1).
                If using data augmentation (`augment=True`),
                this is how many augmentation passes over the data to use.
            seed: Int (default: None). Random seed.
        """
        x = np.asarray(x, dtype=backend.floatx())
        if x.ndim != 4:
            raise ValueError('Input to `.fit()` should have rank 4. '
                             'Got array with shape: ' + str(x.shape))
        # An unusual channel count is only a warning, not an error, so
        # non-image rank-4 data can still be fit deliberately.
        if x.shape[self.channel_axis] not in {1, 3, 4}:
            warnings.warn(
                'Expected input to be images (as Numpy array) '
                'following the data format convention "' +
                self.data_format + '" (channels on axis ' +
                str(self.channel_axis) + '), i.e. expected '
                'either 1, 3 or 4 channels on axis ' +
                str(self.channel_axis) + '. '
                'However, it was passed an array with shape ' +
                str(x.shape) + ' (' + str(x.shape[self.channel_axis]) +
                ' channels).')

        if seed is not None:
            np.random.seed(seed)

        # Copy so the in-place centering/scaling below never mutates the
        # caller's array.
        x = np.copy(x)
        if augment:
            # Compute statistics over `rounds` randomly augmented copies of
            # the data instead of the raw samples.
            ax = np.zeros(
                tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
                dtype=backend.floatx())
            for r in range(rounds):
                for i in range(x.shape[0]):
                    ax[i + r * x.shape[0]] = self.random_transform(x[i])
            x = ax

        if self.featurewise_center:
            # Per-channel mean, reshaped so it broadcasts over a single image.
            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean

        if self.featurewise_std_normalization:
            # Per-channel std, reshaped the same way as the mean above.
            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.std = np.reshape(self.std, broadcast_shape)
            x /= (self.std + backend.epsilon())

        if self.zca_whitening:
            # ZCA: whitening matrix U * diag(1/sqrt(s + eps)) * U^T derived
            # from the SVD of the flattened data's covariance.
            flat_x = np.reshape(
                x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
            u, s, _ = linalg.svd(sigma)
            s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon)
            self.principal_components = (u * s_inv).dot(u.T)
1286
1287
1288
class Iterator(keras_utils.Sequence):
    """Base class for image data iterators.

    Every `Iterator` must implement the `_get_batches_of_transformed_samples`
    method.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        # Position of the next batch within the current epoch.
        self.batch_index = 0
        # Total number of batches produced so far; combined with `seed` to
        # re-seed numpy per batch for reproducible augmentation.
        self.total_batches_seen = 0
        # Lock so multiple threads can safely share `index_generator`.
        self.lock = threading.Lock()
        self.index_array = None
        self.index_generator = self._flow_index()

    def _set_index_array(self):
        """Rebuilds the epoch's index array, shuffled if configured."""
        self.index_array = np.arange(self.n)
        if self.shuffle:
            self.index_array = np.random.permutation(self.n)

    def __getitem__(self, idx):
        """Returns batch number `idx` (the `keras.utils.Sequence` API)."""
        if idx >= len(self):
            raise ValueError('Asked to retrieve element {idx}, '
                             'but the Sequence '
                             'has length {length}'.format(idx=idx,
                                                          length=len(self)))
        if self.seed is not None:
            # Re-seed per batch so results are reproducible even when
            # batches are requested out of order (e.g. by parallel workers).
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1
        if self.index_array is None:
            self._set_index_array()
        index_array = self.index_array[self.batch_size * idx:
                                       self.batch_size * (idx + 1)]
        return self._get_batches_of_transformed_samples(index_array)

    def __len__(self):
        # Number of batches per epoch.
        return (self.n + self.batch_size - 1) // self.batch_size  # round up

    def on_epoch_end(self):
        # Reshuffle (if enabled) at each epoch boundary.
        self._set_index_array()

    def reset(self):
        # Restart iteration from the beginning of the epoch.
        self.batch_index = 0

    def _flow_index(self):
        """Infinite generator yielding one index array per batch."""
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            if self.batch_index == 0:
                self._set_index_array()

            current_index = (self.batch_index * self.batch_size) % self.n
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                # Final (possibly short) batch of the epoch: wrap around.
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[current_index:
                                   current_index + self.batch_size]

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        # Py3 protocol method; delegates to the subclass-provided `next`.
        return self.next(*args, **kwargs)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        # Arguments
            index_array: Array of sample indices to include in batch.

        # Returns
            A batch of transformed samples.
        """
        raise NotImplementedError
1377
1378
1379
class NumpyArrayIterator(Iterator):
    """Iterator yielding data from a Numpy array.

    # Arguments
        x: Numpy array of input data or tuple.
            If tuple, the second elements is either
            another numpy array or a list of numpy arrays,
            each of which gets passed
            through as an output without any modifications.
        y: Numpy array of targets data.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        sample_weight: Numpy array of sample weights.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
    """

    def __init__(self, x, y, image_data_generator,
                 batch_size=32, shuffle=False, sample_weight=None,
                 seed=None, data_format=None,
                 save_to_dir=None, save_prefix='', save_format='png',
                 subset=None):
        # `x` may be a (images, misc) pair where `misc` is an extra array
        # (or list of arrays) passed through to the output unmodified.
        if isinstance(x, (tuple, list)):
            if not isinstance(x[1], list):
                x_misc = [np.asarray(x[1])]
            else:
                x_misc = [np.asarray(xx) for xx in x[1]]
            x = x[0]
            for xx in x_misc:
                if len(x) != len(xx):
                    raise ValueError(
                        'All of the arrays in `x` '
                        'should have the same length. '
                        'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' %
                        (len(x), len(xx)))
        else:
            x_misc = []

        if y is not None and len(x) != len(y):
            raise ValueError('`x` (images tensor) and `y` (labels) '
                             'should have the same length. '
                             'Found: x.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))
        if sample_weight is not None and len(x) != len(sample_weight):
            raise ValueError('`x` (images tensor) and `sample_weight` '
                             'should have the same length. '
                             'Found: x.shape = %s, sample_weight.shape = %s' %
                             (np.asarray(x).shape, np.asarray(sample_weight).shape))
        if subset is not None:
            if subset not in {'training', 'validation'}:
                raise ValueError('Invalid subset name:', subset,
                                 '; expected "training" or "validation".')
            split_idx = int(len(x) * image_data_generator._validation_split)
            # NOTE(review): `sample_weight` is not sliced by the subset
            # split below, so it stays aligned with the *full* `x`; confirm
            # whether callers depend on this before changing it.
            if subset == 'validation':
                x = x[:split_idx]
                x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
                if y is not None:
                    y = y[:split_idx]
            else:
                x = x[split_idx:]
                x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
                if y is not None:
                    y = y[split_idx:]
        if data_format is None:
            data_format = backend.image_data_format()
        self.x = np.asarray(x, dtype=backend.floatx())
        self.x_misc = x_misc
        if self.x.ndim != 4:
            raise ValueError('Input data in `NumpyArrayIterator` '
                             'should have rank 4. You passed an array '
                             'with shape', self.x.shape)
        channels_axis = 3 if data_format == 'channels_last' else 1
        if self.x.shape[channels_axis] not in {1, 3, 4}:
            # Unusual channel counts are allowed, but flagged loudly.
            warnings.warn('NumpyArrayIterator is set to use the '
                          'data format convention "' + data_format + '" '
                          '(channels on axis ' + str(channels_axis) +
                          '), i.e. expected either 1, 3 or 4 '
                          'channels on axis ' + str(channels_axis) + '. '
                          'However, it was passed an array with shape ' +
                          str(self.x.shape) + ' (' +
                          str(self.x.shape[channels_axis]) + ' channels).')
        self.y = np.asarray(y) if y is not None else None
        self.sample_weight = (np.asarray(sample_weight)
                              if sample_weight is not None else None)
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        super(NumpyArrayIterator, self).__init__(x.shape[0],
                                                 batch_size,
                                                 shuffle,
                                                 seed)

    def _get_batches_of_transformed_samples(self, index_array):
        """Apply random transforms + standardization to the chosen samples.

        # Arguments
            index_array: Array of sample indices to include in the batch.

        # Returns
            The transformed images, grouped (when present) with the
            pass-through misc arrays, labels, and sample weights, mirroring
            what was supplied at construction time.
        """
        batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
                           dtype=backend.floatx())
        for i, j in enumerate(index_array):
            x = self.x[j]
            params = self.image_data_generator.get_random_transform(x.shape)
            x = self.image_data_generator.apply_transform(
                x.astype(backend.floatx()), params)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x

        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    # Use an int bound: float arguments to `randint` are
                    # deprecated in recent NumPy releases.
                    hash=np.random.randint(10000),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        batch_x_miscs = [xx[index_array] for xx in self.x_misc]
        output = (batch_x if batch_x_miscs == []
                  else [batch_x] + batch_x_miscs,)
        if self.y is None:
            return output[0]
        output += (self.y[index_array],)
        if self.sample_weight is not None:
            output += (self.sample_weight[index_array],)
        return output

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)
def _iter_valid_files(directory, white_list_formats, follow_links):
    """Iterates on files with extension in `white_list_formats` contained in `directory`.

    # Arguments
        directory: Absolute path to the directory
            containing files to be counted
        white_list_formats: Set of strings containing allowed extensions for
            the files to be counted.
        follow_links: Boolean.

    # Yields
        Tuple of (root, filename) with extension in `white_list_formats`.
    """
    def _recursive_list(subpath):
        # Sort the walk by directory path so iteration order is
        # deterministic across platforms and filesystems.
        return sorted(os.walk(subpath, followlinks=follow_links),
                      key=lambda x: x[0])

    # Build the suffix tuple once so each filename needs a single
    # `endswith` call instead of one test per whitelisted extension.
    suffixes = tuple('.' + extension for extension in white_list_formats)
    for root, _, files in _recursive_list(directory):
        for fname in sorted(files):
            lower = fname.lower()
            # Warn once per file; previously this check sat inside the
            # per-extension loop and was re-evaluated for every extension.
            if lower.endswith('.tiff'):
                warnings.warn('Using \'.tiff\' files with multiple bands '
                              'will cause distortion. '
                              'Please verify your output.')
            if lower.endswith(suffixes):
                yield root, fname
def _count_valid_files_in_directory(directory,
                                    white_list_formats,
                                    split,
                                    follow_links):
    """Counts files with extension in `white_list_formats` contained in `directory`.

    # Arguments
        directory: absolute path to the directory
            containing files to be counted
        white_list_formats: set of strings containing allowed extensions for
            the files to be counted.
        split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into
            account a certain fraction of files in each directory.
            E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent
            of images in each directory.
        follow_links: boolean.

    # Returns
        the count of files with extension in `white_list_formats` contained in
        the directory.
    """
    # Count lazily instead of materializing the full file list.
    total = sum(1 for _ in _iter_valid_files(
        directory, white_list_formats, follow_links))
    if not split:
        return total
    start, stop = int(split[0] * total), int(split[1] * total)
    return stop - start
def _list_valid_filenames_in_directory(directory, white_list_formats, split,
                                       class_indices, follow_links):
    """Lists paths of files in `subdir` with extensions in `white_list_formats`.

    # Arguments
        directory: absolute path to a directory containing the files to list.
            The directory name is used as class label
            and must be a key of `class_indices`.
        white_list_formats: set of strings containing allowed extensions for
            the files to be counted.
        split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into
            account a certain fraction of files in each directory.
            E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent
            of images in each directory.
        class_indices: dictionary mapping a class name to its index.
        follow_links: boolean.

    # Returns
        classes: a list of class indices
        filenames: the path of valid files in `directory`, relative from
            `directory`'s parent (e.g., if `directory` is "dataset/class1",
            the filenames will be
            `["class1/file1.jpg", "class1/file2.jpg", ...]`).
    """
    dirname = os.path.basename(directory)
    if split:
        # Walk the directory once and slice, instead of walking twice
        # (once to count, once to list) as before.
        all_files = list(_iter_valid_files(
            directory, white_list_formats, follow_links))
        num_files = len(all_files)
        start, stop = int(split[0] * num_files), int(split[1] * num_files)
        valid_files = all_files[start: stop]
    else:
        valid_files = _iter_valid_files(
            directory, white_list_formats, follow_links)

    classes = []
    filenames = []
    for root, fname in valid_files:
        classes.append(class_indices[dirname])
        absolute_path = os.path.join(root, fname)
        # Paths are reported relative to the *parent* of `directory`,
        # i.e. prefixed with the class subdirectory name.
        relative_path = os.path.join(
            dirname, os.path.relpath(absolute_path, directory))
        filenames.append(relative_path)

    return classes, filenames
class DirectoryIterator(Iterator):
1642
    """Iterator capable of reading images from a directory on disk.
1643
1644
    # Arguments
1645
        directory: Path to the directory to read images from.
1646
            Each subdirectory in this directory will be
1647
            considered to contain images from one class,
1648
            or alternatively you could specify class subdirectories
1649
            via the `classes` argument.
1650
        image_data_generator: Instance of `ImageDataGenerator`
1651
            to use for random transformations and normalization.
1652
        target_size: tuple of integers, dimensions to resize input images to.
1653
        color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images.
1654
        classes: Optional list of strings, names of subdirectories
1655
            containing images from each class (e.g. `["dogs", "cats"]`).
1656
            It will be computed automatically if not set.
1657
        class_mode: Mode for yielding the targets:
1658
            `"binary"`: binary targets (if there are only two classes),
1659
            `"categorical"`: categorical targets,
1660
            `"sparse"`: integer targets,
1661
            `"input"`: targets are images identical to input images (mainly
1662
                used to work with autoencoders),
1663
            `None`: no targets get yielded (only input images are yielded).
1664
        batch_size: Integer, size of a batch.
1665
        shuffle: Boolean, whether to shuffle the data between epochs.
1666
        seed: Random seed for data shuffling.
1667
        data_format: String, one of `channels_first`, `channels_last`.
1668
        save_to_dir: Optional directory where to save the pictures
1669
            being yielded, in a viewable format. This is useful
1670
            for visualizing the random transformations being
1671
            applied, for debugging purposes.
1672
        save_prefix: String prefix to use for saving sample
1673
            images (if `save_to_dir` is set).
1674
        save_format: Format to use for saving sample images
1675
            (if `save_to_dir` is set).
1676
        subset: Subset of data (`"training"` or `"validation"`) if
1677
            validation_split is set in ImageDataGenerator.
1678
        interpolation: Interpolation method used to resample the image if the
1679
            target size is different from that of the loaded image.
1680
            Supported methods are "nearest", "bilinear", and "bicubic".
1681
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
1682
            supported. If PIL version 3.4.0 or newer is installed, "box" and
1683
            "hamming" are also supported. By default, "nearest" is used.
1684
    """
1685
1686
    def __init__(self, directory, image_data_generator,
                 target_size=(256, 256), color_mode='rgb',
                 classes=None, class_mode='categorical',
                 batch_size=32, shuffle=True, seed=None,
                 data_format=None,
                 save_to_dir=None, save_prefix='', save_format='png',
                 follow_links=False,
                 subset=None,
                 interpolation='nearest'):
        # See the class docstring for argument semantics. This constructor
        # validates options, scans `directory` for class subfolders, then
        # indexes every valid image file using a thread pool.
        if data_format is None:
            data_format = backend.image_data_format()
        self.directory = directory
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        if color_mode not in {'rgb', 'grayscale'}:
            raise ValueError('Invalid color mode:', color_mode,
                             '; expected "rgb" or "grayscale".')
        self.color_mode = color_mode
        self.data_format = data_format
        # Derive the per-sample shape from color mode + data format:
        # 3 channels for RGB, 1 for grayscale, placed last or first.
        if self.color_mode == 'rgb':
            if self.data_format == 'channels_last':
                self.image_shape = self.target_size + (3,)
            else:
                self.image_shape = (3,) + self.target_size
        else:
            if self.data_format == 'channels_last':
                self.image_shape = self.target_size + (1,)
            else:
                self.image_shape = (1,) + self.target_size
        self.classes = classes
        if class_mode not in {'categorical', 'binary', 'sparse',
                              'input', None}:
            raise ValueError('Invalid class_mode:', class_mode,
                             '; expected one of "categorical", '
                             '"binary", "sparse", "input"'
                             ' or None.')
        self.class_mode = class_mode
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        # Translate the subset name into a (start, stop) fraction pair
        # used to split each class folder's file list.
        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError('Invalid subset name: ', subset,
                                 '; expected "training" or "validation"')
        else:
            split = None
        self.subset = subset

        white_list_formats = {'png', 'jpg', 'jpeg', 'bmp',
                              'ppm', 'tif', 'tiff'}
        # First, count the number of samples and classes.
        self.samples = 0

        # If class names were not given, every subdirectory of `directory`
        # becomes a class, in sorted (deterministic) order.
        if not classes:
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        # Count files per class folder in parallel (I/O bound work).
        pool = multiprocessing.pool.ThreadPool()
        function_partial = partial(_count_valid_files_in_directory,
                                   white_list_formats=white_list_formats,
                                   follow_links=follow_links,
                                   split=split)
        self.samples = sum(pool.map(function_partial,
                                    (os.path.join(directory, subdir)
                                     for subdir in classes)))

        print('Found %d images belonging to %d classes.' %
              (self.samples, self.num_classes))

        # Second, build an index of the images
        # in the different class subfolders.
        results = []
        self.filenames = []
        self.classes = np.zeros((self.samples,), dtype='int32')
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(
                pool.apply_async(_list_valid_filenames_in_directory,
                                 (dirpath, white_list_formats, split,
                                  self.class_indices, follow_links)))
        # Collect results in submission order so `self.classes` and
        # `self.filenames` stay aligned and deterministic.
        # (Note: `classes` is rebound here from class names to per-file
        # class indices returned by the workers.)
        for res in results:
            classes, filenames = res.get()
            self.classes[i:i + len(classes)] = classes
            self.filenames += filenames
            i += len(classes)

        pool.close()
        pool.join()
        super(DirectoryIterator, self).__init__(self.samples,
                                                batch_size,
                                                shuffle,
                                                seed)
    def _get_batches_of_transformed_samples(self, index_array):
        """Load, augment and standardize the images selected by `index_array`,
        returning `(batch_x, batch_y)` according to `self.class_mode`."""
        batch_x = np.zeros(
            (len(index_array),) + self.image_shape,
            dtype=backend.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for batch_idx, sample_idx in enumerate(index_array):
            fname = self.filenames[sample_idx]
            img = load_img(os.path.join(self.directory, fname),
                           grayscale=grayscale,
                           target_size=self.target_size,
                           interpolation=self.interpolation)
            pixels = img_to_array(img, data_format=self.data_format)
            params = self.image_data_generator.get_random_transform(
                pixels.shape)
            pixels = self.image_data_generator.apply_transform(pixels, params)
            pixels = self.image_data_generator.standardize(pixels)
            batch_x[batch_idx] = pixels
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for batch_idx, sample_idx in enumerate(index_array):
                img = array_to_img(batch_x[batch_idx],
                                   self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=sample_idx,
                    hash=np.random.randint(1e7),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels (guard-clause form, one return per mode)
        if self.class_mode == 'input':
            return batch_x, batch_x.copy()
        if self.class_mode == 'sparse':
            return batch_x, self.classes[index_array]
        if self.class_mode == 'binary':
            return batch_x, self.classes[index_array].astype(backend.floatx())
        if self.class_mode == 'categorical':
            batch_y = np.zeros(
                (len(batch_x), self.num_classes),
                dtype=backend.floatx())
            for row, label in enumerate(self.classes[index_array]):
                batch_y[row, label] = 1.
            return batch_x, batch_y
        # class_mode is None: yield only inputs.
        return batch_x
    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        # Only the index bookkeeping is serialized; the expensive image
        # transformation runs outside the lock so it can proceed in parallel.
        with self.lock:
            index_array = next(self.index_generator)
        return self._get_batches_of_transformed_samples(index_array)