Diff of /keras_CNN/image_5ch.py [000000] .. [797475]

Switch to side-by-side view

--- a
+++ b/keras_CNN/image_5ch.py
@@ -0,0 +1,719 @@
+'''Fairly basic set of tools for real-time data augmentation on image data.
+Can easily be extended to include new transformations,
+new preprocessing methods, etc...
+
+This is a direct fork of the Keras built-in "preprocessing/image.py",
+which was modified to allow 5-channel images (which Keras used to support...)
+'''
+from __future__ import absolute_import
+from __future__ import print_function
+
+import numpy as np
+import re
+from scipy import linalg
+import scipy.ndimage as ndi
+from six.moves import range
+import os
+import threading
+import warnings
+
+from keras import backend as K
+
+
+def random_rotation(x, rg, row_index=1, col_index=2, channel_index=0,
+                    fill_mode='nearest', cval=0.):
+    theta = np.pi / 180 * np.random.uniform(-rg, rg)
+    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
+                                [np.sin(theta), np.cos(theta), 0],
+                                [0, 0, 1]])
+
+    h, w = x.shape[row_index], x.shape[col_index]
+    transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
+    x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
+    return x
+
+
+def random_shift(x, wrg, hrg, row_index=1, col_index=2, channel_index=0,
+                 fill_mode='nearest', cval=0.):
+    h, w = x.shape[row_index], x.shape[col_index]
+    tx = np.random.uniform(-hrg, hrg) * h
+    ty = np.random.uniform(-wrg, wrg) * w
+    translation_matrix = np.array([[1, 0, tx],
+                                   [0, 1, ty],
+                                   [0, 0, 1]])
+
+    transform_matrix = translation_matrix  # no need to do offset
+    x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
+    return x
+
+
+def random_shear(x, intensity, row_index=1, col_index=2, channel_index=0,
+                 fill_mode='nearest', cval=0.):
+    shear = np.random.uniform(-intensity, intensity)
+    shear_matrix = np.array([[1, -np.sin(shear), 0],
+                             [0, np.cos(shear), 0],
+                             [0, 0, 1]])
+
+    h, w = x.shape[row_index], x.shape[col_index]
+    transform_matrix = transform_matrix_offset_center(shear_matrix, h, w)
+    x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
+    return x
+
+
+def random_zoom(x, zoom_range, row_index=1, col_index=2, channel_index=0,
+                fill_mode='nearest', cval=0.):
+    if len(zoom_range) != 2:
+        raise ValueError('zoom_range should be a tuple or list of two floats. '
+                         'Received arg: ', zoom_range)
+
+    if zoom_range[0] == 1 and zoom_range[1] == 1:
+        zx, zy = 1, 1
+    else:
+        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
+    zoom_matrix = np.array([[zx, 0, 0],
+                            [0, zy, 0],
+                            [0, 0, 1]])
+
+    h, w = x.shape[row_index], x.shape[col_index]
+    transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w)
+    x = apply_transform(x, transform_matrix, channel_index, fill_mode, cval)
+    return x
+
+
+def random_barrel_transform(x, intensity):
+    # TODO
+    pass
+
+
+def random_channel_shift(x, intensity, channel_index=0):
+    x = np.rollaxis(x, channel_index, 0)
+    min_x, max_x = np.min(x), np.max(x)
+    channel_images = [np.clip(x_channel + np.random.uniform(-intensity, intensity), min_x, max_x)
+                      for x_channel in x]
+    x = np.stack(channel_images, axis=0)
+    x = np.rollaxis(x, 0, channel_index+1)
+    return x
+
+
+def transform_matrix_offset_center(matrix, x, y):
+    o_x = float(x) / 2 + 0.5
+    o_y = float(y) / 2 + 0.5
+    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
+    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
+    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
+    return transform_matrix
+
+
+def apply_transform(x, transform_matrix, channel_index=0, fill_mode='nearest', cval=0.):
+    x = np.rollaxis(x, channel_index, 0)
+    final_affine_matrix = transform_matrix[:2, :2]
+    final_offset = transform_matrix[:2, 2]
+    channel_images = [ndi.interpolation.affine_transform(x_channel, final_affine_matrix,
+                      final_offset, order=0, mode=fill_mode, cval=cval) for x_channel in x]
+    x = np.stack(channel_images, axis=0)
+    x = np.rollaxis(x, 0, channel_index+1)
+    return x
+
+
+def flip_axis(x, axis):
+    x = np.asarray(x).swapaxes(axis, 0)
+    x = x[::-1, ...]
+    x = x.swapaxes(0, axis)
+    return x
+
+
+def array_to_img(x, dim_ordering='default', scale=True):
+    from PIL import Image
+    x = np.asarray(x)
+    if x.ndim != 3:
+        raise ValueError('Expected image array to have rank 3 (single image). '
+                         'Got array with shape:', x.shape)
+
+    if dim_ordering == 'default':
+        dim_ordering = K.image_dim_ordering()
+    if dim_ordering not in {'th', 'tf'}:
+        raise ValueError('Invalid dim_ordering:', dim_ordering)
+
+    # Original Numpy array x has format (height, width, channel)
+    # or (channel, height, width)
+    # but target PIL image has format (width, height, channel)
+    if dim_ordering == 'th':
+        x = x.transpose(1, 2, 0)
+    if scale:
+        x += max(-np.min(x), 0)
+        x_max = np.max(x)
+        if x_max != 0:
+            x /= x_max
+        x *= 255
+    if x.shape[2] == 3:
+        # RGB
+        return Image.fromarray(x.astype('uint8'), 'RGB')
+    elif x.shape[2] == 1:
+        # grayscale
+        return Image.fromarray(x[:, :, 0].astype('uint8'), 'L')
+    else:
+        raise ValueError('Unsupported channel number: ', x.shape[2])
+
+
+def img_to_array(img, dim_ordering='default'):
+    if dim_ordering == 'default':
+        dim_ordering = K.image_dim_ordering()
+    if dim_ordering not in {'th', 'tf'}:
+        raise ValueError('Unknown dim_ordering: ', dim_ordering)
+    # Numpy array x has format (height, width, channel)
+    # or (channel, height, width)
+    # but original PIL image has format (width, height, channel)
+    x = np.asarray(img, dtype='float32')
+    if len(x.shape) == 3:
+        if dim_ordering == 'th':
+            x = x.transpose(2, 0, 1)
+    elif len(x.shape) == 2:
+        if dim_ordering == 'th':
+            x = x.reshape((1, x.shape[0], x.shape[1]))
+        else:
+            x = x.reshape((x.shape[0], x.shape[1], 1))
+    else:
+        raise ValueError('Unsupported image shape: ', x.shape)
+    return x
+
+
+def load_img(path, grayscale=False, target_size=None):
+    '''Load an image into PIL format.
+
+    # Arguments
+        path: path to image file
+        grayscale: boolean
+        target_size: None (default to original size)
+            or (img_height, img_width)
+    '''
+    from PIL import Image
+    img = Image.open(path)
+    if grayscale:
+        img = img.convert('L')
+    else:  # Ensure 3 channel even when loaded image is grayscale
+        img = img.convert('RGB')
+    if target_size:
+        img = img.resize((target_size[1], target_size[0]))
+    return img
+
+
+def list_pictures(directory, ext='jpg|jpeg|bmp|png'):
+    return [os.path.join(root, f)
+            for root, dirs, files in os.walk(directory) for f in files
+            if re.match('([\w]+\.(?:' + ext + '))', f)]
+
+
+class ImageDataGenerator(object):
+    '''Generate minibatches with
+    real-time data augmentation.
+
+    # Arguments
+        featurewise_center: set input mean to 0 over the dataset.
+        samplewise_center: set each sample mean to 0.
+        featurewise_std_normalization: divide inputs by std of the dataset.
+        samplewise_std_normalization: divide each input by its std.
+        zca_whitening: apply ZCA whitening.
+        rotation_range: degrees (0 to 180).
+        width_shift_range: fraction of total width.
+        height_shift_range: fraction of total height.
+        shear_range: shear intensity (shear angle in radians).
+        zoom_range: amount of zoom. if scalar z, zoom will be randomly picked
+            in the range [1-z, 1+z]. A sequence of two can be passed instead
+            to select this range.
+        channel_shift_range: shift range for each channels.
+        fill_mode: points outside the boundaries are filled according to the
+            given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default
+            is 'nearest'.
+        cval: value used for points outside the boundaries when fill_mode is
+            'constant'. Default is 0.
+        horizontal_flip: whether to randomly flip images horizontally.
+        vertical_flip: whether to randomly flip images vertically.
+        rescale: rescaling factor. If None or 0, no rescaling is applied,
+            otherwise we multiply the data by the value provided
+            (before applying any other transformation).
+        preprocessing_function: function that will be implied on each input.
+            The function will run before any other modification on it.
+            The function should take one argument: one image (Numpy tensor with rank 3),
+            and should output a Numpy tensor with the same shape.
+        dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
+            (the depth) is at index 1, in 'tf' mode it is at index 3.
+            It defaults to the `image_dim_ordering` value found in your
+            Keras config file at `~/.keras/keras.json`.
+            If you never set it, then it will be "th".
+    '''
+    def __init__(self,
+                 featurewise_center=False,
+                 samplewise_center=False,
+                 featurewise_std_normalization=False,
+                 samplewise_std_normalization=False,
+                 zca_whitening=False,
+                 rotation_range=0.,
+                 width_shift_range=0.,
+                 height_shift_range=0.,
+                 shear_range=0.,
+                 zoom_range=0.,
+                 channel_shift_range=0.,
+                 fill_mode='nearest',
+                 cval=0.,
+                 horizontal_flip=False,
+                 vertical_flip=False,
+                 rescale=None,
+                 preprocessing_function=None,
+                 dim_ordering='default'):
+        if dim_ordering == 'default':
+            dim_ordering = K.image_dim_ordering()
+        self.__dict__.update(locals())
+        self.mean = None
+        self.std = None
+        self.principal_components = None
+        self.rescale = rescale
+        self.preprocessing_function = preprocessing_function
+
+        if dim_ordering not in {'tf', 'th'}:
+            raise ValueError('dim_ordering should be "tf" (channel after row and '
+                             'column) or "th" (channel before row and column). '
+                             'Received arg: ', dim_ordering)
+        self.dim_ordering = dim_ordering
+        if dim_ordering == 'th':
+            self.channel_index = 1
+            self.row_index = 2
+            self.col_index = 3
+        if dim_ordering == 'tf':
+            self.channel_index = 3
+            self.row_index = 1
+            self.col_index = 2
+
+        if np.isscalar(zoom_range):
+            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
+        elif len(zoom_range) == 2:
+            self.zoom_range = [zoom_range[0], zoom_range[1]]
+        else:
+            raise ValueError('zoom_range should be a float or '
+                             'a tuple or list of two floats. '
+                             'Received arg: ', zoom_range)
+
+    def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
+             save_to_dir=None, save_prefix='', save_format='jpeg'):
+        return NumpyArrayIterator(
+            X, y, self,
+            batch_size=batch_size, shuffle=shuffle, seed=seed,
+            dim_ordering=self.dim_ordering,
+            save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format)
+
+    def flow_from_directory(self, directory,
+                            target_size=(256, 256), color_mode='rgb',
+                            classes=None, class_mode='categorical',
+                            batch_size=32, shuffle=True, seed=None,
+                            save_to_dir=None, save_prefix='', save_format='jpeg',
+                            follow_links=False):
+        return DirectoryIterator(
+            directory, self,
+            target_size=target_size, color_mode=color_mode,
+            classes=classes, class_mode=class_mode,
+            dim_ordering=self.dim_ordering,
+            batch_size=batch_size, shuffle=shuffle, seed=seed,
+            save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format,
+            follow_links=follow_links)
+
+    def standardize(self, x):
+        if self.preprocessing_function:
+            x = self.preprocessing_function(x)
+        if self.rescale:
+            x *= self.rescale
+        # x is a single image, so it doesn't have image number at index 0
+        img_channel_index = self.channel_index - 1
+        if self.samplewise_center:
+            x -= np.mean(x, axis=img_channel_index, keepdims=True)
+        if self.samplewise_std_normalization:
+            x /= (np.std(x, axis=img_channel_index, keepdims=True) + 1e-7)
+
+        if self.featurewise_center:
+            if self.mean is not None:
+                x -= self.mean
+            else:
+                warnings.warn('This ImageDataGenerator specifies '
+                              '`featurewise_center`, but it hasn\'t'
+                              'been fit on any training data. Fit it '
+                              'first by calling `.fit(numpy_data)`.')
+        if self.featurewise_std_normalization:
+            if self.std is not None:
+                x /= (self.std + 1e-7)
+            else:
+                warnings.warn('This ImageDataGenerator specifies '
+                              '`featurewise_std_normalization`, but it hasn\'t'
+                              'been fit on any training data. Fit it '
+                              'first by calling `.fit(numpy_data)`.')
+        if self.zca_whitening:
+            if self.principal_components is not None:
+                flatx = np.reshape(x, (x.size))
+                whitex = np.dot(flatx, self.principal_components)
+                x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2]))
+            else:
+                warnings.warn('This ImageDataGenerator specifies '
+                              '`zca_whitening`, but it hasn\'t'
+                              'been fit on any training data. Fit it '
+                              'first by calling `.fit(numpy_data)`.')
+        return x
+
+    def random_transform(self, x):
+        # x is a single image, so it doesn't have image number at index 0
+        img_row_index = self.row_index - 1
+        img_col_index = self.col_index - 1
+        img_channel_index = self.channel_index - 1
+
+        # use composition of homographies to generate final transform that needs to be applied
+        if self.rotation_range:
+            theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range)
+        else:
+            theta = 0
+        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
+                                    [np.sin(theta), np.cos(theta), 0],
+                                    [0, 0, 1]])
+        if self.height_shift_range:
+            tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * x.shape[img_row_index]
+        else:
+            tx = 0
+
+        if self.width_shift_range:
+            ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * x.shape[img_col_index]
+        else:
+            ty = 0
+
+        translation_matrix = np.array([[1, 0, tx],
+                                       [0, 1, ty],
+                                       [0, 0, 1]])
+        if self.shear_range:
+            shear = np.random.uniform(-self.shear_range, self.shear_range)
+        else:
+            shear = 0
+        shear_matrix = np.array([[1, -np.sin(shear), 0],
+                                 [0, np.cos(shear), 0],
+                                 [0, 0, 1]])
+
+        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
+            zx, zy = 1, 1
+        else:
+            zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)
+        zoom_matrix = np.array([[zx, 0, 0],
+                                [0, zy, 0],
+                                [0, 0, 1]])
+
+        transform_matrix = np.dot(np.dot(np.dot(rotation_matrix, translation_matrix), shear_matrix), zoom_matrix)
+
+        h, w = x.shape[img_row_index], x.shape[img_col_index]
+        transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
+        x = apply_transform(x, transform_matrix, img_channel_index,
+                            fill_mode=self.fill_mode, cval=self.cval)
+        if self.channel_shift_range != 0:
+            x = random_channel_shift(x, self.channel_shift_range, img_channel_index)
+
+        if self.horizontal_flip:
+            if np.random.random() < 0.5:
+                x = flip_axis(x, img_col_index)
+
+        if self.vertical_flip:
+            if np.random.random() < 0.5:
+                x = flip_axis(x, img_row_index)
+
+        # TODO:
+        # channel-wise normalization
+        # barrel/fisheye
+        return x
+
+    def fit(self, X,
+            augment=False,
+            rounds=1,
+            seed=None):
+        '''Required for featurewise_center, featurewise_std_normalization
+        and zca_whitening.
+
+        # Arguments
+            X: Numpy array, the data to fit on. Should have rank 4.
+                In case of grayscale data,
+                the channels axis should have value 1, and in case
+                of RGB data, it should have value 3.
+            augment: Whether to fit on randomly augmented samples
+            rounds: If `augment`,
+                how many augmentation passes to do over the data
+            seed: random seed.
+        '''
+        X = np.asarray(X)
+        if X.ndim != 4:
+            raise ValueError('Input to `.fit()` should have rank 4. '
+                             'Got array with shape: ' + str(X.shape))
+        if X.shape[self.channel_index] not in {1, 3, 4, 5}:
+            raise ValueError(
+                'Expected input to be images (as Numpy array) '
+                'following the dimension ordering convention "' + self.dim_ordering + '" '
+                '(channels on axis ' + str(self.channel_index) + '), i.e. expected '
+                'either 1, 3, 4, or 5 channels on axis ' + str(self.channel_index) + '. '
+                'However, it was passed an array with shape ' + str(X.shape) +
+                ' (' + str(X.shape[self.channel_index]) + ' channels).')
+
+        if seed is not None:
+            np.random.seed(seed)
+
+        X = np.copy(X)
+        if augment:
+            aX = np.zeros(tuple([rounds * X.shape[0]] + list(X.shape)[1:]))
+            for r in range(rounds):
+                for i in range(X.shape[0]):
+                    aX[i + r * X.shape[0]] = self.random_transform(X[i])
+            X = aX
+
+        if self.featurewise_center:
+            self.mean = np.mean(X, axis=(0, self.row_index, self.col_index))
+            broadcast_shape = [1, 1, 1]
+            broadcast_shape[self.channel_index - 1] = X.shape[self.channel_index]
+            self.mean = np.reshape(self.mean, broadcast_shape)
+            X -= self.mean
+
+        if self.featurewise_std_normalization:
+            self.std = np.std(X, axis=(0, self.row_index, self.col_index))
+            broadcast_shape = [1, 1, 1]
+            broadcast_shape[self.channel_index - 1] = X.shape[self.channel_index]
+            self.std = np.reshape(self.std, broadcast_shape)
+            X /= (self.std + K.epsilon())
+
+        if self.zca_whitening:
+            flatX = np.reshape(X, (X.shape[0], X.shape[1] * X.shape[2] * X.shape[3]))
+            sigma = np.dot(flatX.T, flatX) / flatX.shape[0]
+            U, S, V = linalg.svd(sigma)
+            self.principal_components = np.dot(np.dot(U, np.diag(1. / np.sqrt(S + 10e-7))), U.T)
+
+
+class Iterator(object):
+
+    def __init__(self, N, batch_size, shuffle, seed):
+        self.N = N
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.batch_index = 0
+        self.total_batches_seen = 0
+        self.lock = threading.Lock()
+        self.index_generator = self._flow_index(N, batch_size, shuffle, seed)
+
+    def reset(self):
+        self.batch_index = 0
+
+    def _flow_index(self, N, batch_size=32, shuffle=False, seed=None):
+        # ensure self.batch_index is 0
+        self.reset()
+        while 1:
+            if seed is not None:
+                np.random.seed(seed + self.total_batches_seen)
+            if self.batch_index == 0:
+                index_array = np.arange(N)
+                if shuffle:
+                    index_array = np.random.permutation(N)
+
+            current_index = (self.batch_index * batch_size) % N
+            if N >= current_index + batch_size:
+                current_batch_size = batch_size
+                self.batch_index += 1
+            else:
+                current_batch_size = N - current_index
+                self.batch_index = 0
+            self.total_batches_seen += 1
+            yield (index_array[current_index: current_index + current_batch_size],
+                   current_index, current_batch_size)
+
+    def __iter__(self):
+        # needed if we want to do something like:
+        # for x, y in data_gen.flow(...):
+        return self
+
+    def __next__(self, *args, **kwargs):
+        return self.next(*args, **kwargs)
+
+
+class NumpyArrayIterator(Iterator):
+
+    def __init__(self, X, y, image_data_generator,
+                 batch_size=32, shuffle=False, seed=None,
+                 dim_ordering='default',
+                 save_to_dir=None, save_prefix='', save_format='jpeg'):
+        if y is not None and len(X) != len(y):
+            raise ValueError('X (images tensor) and y (labels) '
+                             'should have the same length. '
+                             'Found: X.shape = %s, y.shape = %s' % (np.asarray(X).shape, np.asarray(y).shape))
+        if dim_ordering == 'default':
+            dim_ordering = K.image_dim_ordering()
+        self.X = np.asarray(X)
+        if self.X.ndim != 4:
+            raise ValueError('Input data in `NumpyArrayIterator` '
+                             'should have rank 4. You passed an array '
+                             'with shape', self.X.shape)
+        channels_axis = 3 if dim_ordering == 'tf' else 1
+        if self.X.shape[channels_axis] not in {1, 3, 4, 5}:
+            raise ValueError('NumpyArrayIterator is set to use the '
+                             'dimension ordering convention "' + dim_ordering + '" '
+                             '(channels on axis ' + str(channels_axis) + '), i.e. expected '
+                             'either 1, 3 or 4, or 5 channels on axis ' + str(channels_axis) + '. '
+                             'However, it was passed an array with shape ' + str(self.X.shape) +
+                             ' (' + str(self.X.shape[channels_axis]) + ' channels).')
+        if y is not None:
+            self.y = np.asarray(y)
+        else:
+            self.y = None
+        self.image_data_generator = image_data_generator
+        self.dim_ordering = dim_ordering
+        self.save_to_dir = save_to_dir
+        self.save_prefix = save_prefix
+        self.save_format = save_format
+        super(NumpyArrayIterator, self).__init__(X.shape[0], batch_size, shuffle, seed)
+
+    def next(self):
+        # for python 2.x.
+        # Keeps under lock only the mechanism which advances
+        # the indexing of each batch
+        # see http://anandology.com/blog/using-iterators-and-generators/
+        with self.lock:
+            index_array, current_index, current_batch_size = next(self.index_generator)
+        # The transformation of images is not under thread lock so it can be done in parallel
+        batch_x = np.zeros(tuple([current_batch_size] + list(self.X.shape)[1:]))
+        for i, j in enumerate(index_array):
+            x = self.X[j]
+            x = self.image_data_generator.random_transform(x.astype('float32'))
+            x = self.image_data_generator.standardize(x)
+            batch_x[i] = x
+        if self.save_to_dir:
+            for i in range(current_batch_size):
+                img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
+                fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
+                                                                  index=current_index + i,
+                                                                  hash=np.random.randint(1e4),
+                                                                  format=self.save_format)
+                img.save(os.path.join(self.save_to_dir, fname))
+        if self.y is None:
+            return batch_x
+        batch_y = self.y[index_array]
+        return batch_x, batch_y
+
+
+class DirectoryIterator(Iterator):
+
+    def __init__(self, directory, image_data_generator,
+                 target_size=(256, 256), color_mode='rgb',
+                 dim_ordering='default',
+                 classes=None, class_mode='categorical',
+                 batch_size=32, shuffle=True, seed=None,
+                 save_to_dir=None, save_prefix='', save_format='jpeg',
+                 follow_links=False):
+        if dim_ordering == 'default':
+            dim_ordering = K.image_dim_ordering()
+        self.directory = directory
+        self.image_data_generator = image_data_generator
+        self.target_size = tuple(target_size)
+        if color_mode not in {'rgb', 'grayscale'}:
+            raise ValueError('Invalid color mode:', color_mode,
+                             '; expected "rgb" or "grayscale".')
+        self.color_mode = color_mode
+        self.dim_ordering = dim_ordering
+        if self.color_mode == 'rgb':
+            if self.dim_ordering == 'tf':
+                self.image_shape = self.target_size + (3,)
+            else:
+                self.image_shape = (3,) + self.target_size
+        else:
+            if self.dim_ordering == 'tf':
+                self.image_shape = self.target_size + (1,)
+            else:
+                self.image_shape = (1,) + self.target_size
+        self.classes = classes
+        if class_mode not in {'categorical', 'binary', 'sparse', None}:
+            raise ValueError('Invalid class_mode:', class_mode,
+                             '; expected one of "categorical", '
+                             '"binary", "sparse", or None.')
+        self.class_mode = class_mode
+        self.save_to_dir = save_to_dir
+        self.save_prefix = save_prefix
+        self.save_format = save_format
+
+        white_list_formats = {'png', 'jpg', 'jpeg', 'bmp'}
+
+        # first, count the number of samples and classes
+        self.nb_sample = 0
+
+        if not classes:
+            classes = []
+            for subdir in sorted(os.listdir(directory)):
+                if os.path.isdir(os.path.join(directory, subdir)):
+                    classes.append(subdir)
+        self.nb_class = len(classes)
+        self.class_indices = dict(zip(classes, range(len(classes))))
+
+        def _recursive_list(subpath):
+            return sorted(os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0])
+
+        for subdir in classes:
+            subpath = os.path.join(directory, subdir)
+            for root, dirs, files in _recursive_list(subpath):
+                for fname in files:
+                    is_valid = False
+                    for extension in white_list_formats:
+                        if fname.lower().endswith('.' + extension):
+                            is_valid = True
+                            break
+                    if is_valid:
+                        self.nb_sample += 1
+        print('Found %d images belonging to %d classes.' % (self.nb_sample, self.nb_class))
+
+        # second, build an index of the images in the different class subfolders
+        self.filenames = []
+        self.classes = np.zeros((self.nb_sample,), dtype='int32')
+        i = 0
+        for subdir in classes:
+            subpath = os.path.join(directory, subdir)
+            for root, dirs, files in _recursive_list(subpath):
+                for fname in files:
+                    is_valid = False
+                    for extension in white_list_formats:
+                        if fname.lower().endswith('.' + extension):
+                            is_valid = True
+                            break
+                    if is_valid:
+                        self.classes[i] = self.class_indices[subdir]
+                        i += 1
+                        # add filename relative to directory
+                        absolute_path = os.path.join(root, fname)
+                        self.filenames.append(os.path.relpath(absolute_path, directory))
+        super(DirectoryIterator, self).__init__(self.nb_sample, batch_size, shuffle, seed)
+
+    def next(self):
+        with self.lock:
+            index_array, current_index, current_batch_size = next(self.index_generator)
+        # The transformation of images is not under thread lock so it can be done in parallel
+        batch_x = np.zeros((current_batch_size,) + self.image_shape)
+        grayscale = self.color_mode == 'grayscale'
+        # build batch of image data
+        for i, j in enumerate(index_array):
+            fname = self.filenames[j]
+            img = load_img(os.path.join(self.directory, fname),
+                           grayscale=grayscale,
+                           target_size=self.target_size)
+            x = img_to_array(img, dim_ordering=self.dim_ordering)
+            x = self.image_data_generator.random_transform(x)
+            x = self.image_data_generator.standardize(x)
+            batch_x[i] = x
+        # optionally save augmented images to disk for debugging purposes
+        if self.save_to_dir:
+            for i in range(current_batch_size):
+                img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
+                fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
+                                                                  index=current_index + i,
+                                                                  hash=np.random.randint(1e4),
+                                                                  format=self.save_format)
+                img.save(os.path.join(self.save_to_dir, fname))
+        # build batch of labels
+        if self.class_mode == 'sparse':
+            batch_y = self.classes[index_array]
+        elif self.class_mode == 'binary':
+            batch_y = self.classes[index_array].astype('float32')
+        elif self.class_mode == 'categorical':
+            batch_y = np.zeros((len(batch_x), self.nb_class), dtype='float32')
+            for i, label in enumerate(self.classes[index_array]):
+                batch_y[i, label] = 1.
+        else:
+            return batch_x
+        return batch_x, batch_y