--- a +++ b/image.py @@ -0,0 +1,1846 @@ +"""Utilities for real-time data augmentation on image data. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import re +from scipy import linalg +import scipy.ndimage as ndi +from six.moves import range +import os +import threading +import warnings +import multiprocessing.pool +import cv2 +from functools import partial +from skimage import data, img_as_float +from skimage import exposure + +from . import get_keras_submodule + +backend = get_keras_submodule('backend') +keras_utils = get_keras_submodule('utils') + +try: + from PIL import ImageEnhance + from PIL import Image as pil_image +except ImportError: + pil_image = None + +if pil_image is not None: + _PIL_INTERPOLATION_METHODS = { + 'nearest': pil_image.NEAREST, + 'bilinear': pil_image.BILINEAR, + 'bicubic': pil_image.BICUBIC, + 'antialias' : pil_image.ANTIALIAS, + } + # These methods were only introduced in version 3.4.0 (2016). + if hasattr(pil_image, 'HAMMING'): + _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING + if hasattr(pil_image, 'BOX'): + _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX + # This method is new in version 1.1.3 (2013). + if hasattr(pil_image, 'LANCZOS'): + _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS + + +def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0, + fill_mode='nearest', cval=0.): + """Performs a random rotation of a Numpy image tensor. + + # Arguments + x: Input tensor. Must be 3D. + rg: Rotation range, in degrees. + row_axis: Index of axis for rows in the input tensor. + col_axis: Index of axis for columns in the input tensor. + channel_axis: Index of axis for channels in the input tensor. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + Rotated Numpy image tensor. + """ + theta = np.random.uniform(-rg, rg) + x = apply_affine_transform(x, theta=theta, channel_axis=channel_axis, + fill_mode=fill_mode, cval=cval) + return x + + +def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0, + fill_mode='nearest', cval=0.): + """Performs a random spatial shift of a Numpy image tensor. + + # Arguments + x: Input tensor. Must be 3D. + wrg: Width shift range, as a float fraction of the width. + hrg: Height shift range, as a float fraction of the height. + row_axis: Index of axis for rows in the input tensor. + col_axis: Index of axis for columns in the input tensor. + channel_axis: Index of axis for channels in the input tensor. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + Shifted Numpy image tensor. + """ + h, w = x.shape[row_axis], x.shape[col_axis] + tx = np.random.uniform(-hrg, hrg) * h + ty = np.random.uniform(-wrg, wrg) * w + x = apply_affine_transform(x, tx=tx, ty=ty, channel_axis=channel_axis, + fill_mode=fill_mode, cval=cval) + return x + + +def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0, + fill_mode='nearest', cval=0.): + """Performs a random spatial shear of a Numpy image tensor. + + # Arguments + x: Input tensor. Must be 3D. + intensity: Transformation intensity in degrees. + row_axis: Index of axis for rows in the input tensor. + col_axis: Index of axis for columns in the input tensor. + channel_axis: Index of axis for channels in the input tensor. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + Sheared Numpy image tensor. + """ + shear = np.random.uniform(-intensity, intensity) + x = apply_affine_transform(x, shear=shear, channel_axis=channel_axis, + fill_mode=fill_mode, cval=cval) + return x + + +def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0, + fill_mode='nearest', cval=0.): + """Performs a random spatial zoom of a Numpy image tensor. + + # Arguments + x: Input tensor. Must be 3D. + zoom_range: Tuple of floats; zoom range for width and height. + row_axis: Index of axis for rows in the input tensor. + col_axis: Index of axis for columns in the input tensor. + channel_axis: Index of axis for channels in the input tensor. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + Zoomed Numpy image tensor. + + # Raises + ValueError: if `zoom_range` isn't a tuple. + """ + if len(zoom_range) != 2: + raise ValueError('`zoom_range` should be a tuple or list of two' + ' floats. Received: ', zoom_range) + + if zoom_range[0] == 1 and zoom_range[1] == 1: + zx, zy = 1, 1 + else: + zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) + x = apply_affine_transform(x, zx=zx, zy=zy, channel_axis=channel_axis, + fill_mode=fill_mode, cval=cval) + return x + + +def apply_channel_shift(x, intensity, channel_axis=0): + """Performs a channel shift. + + # Arguments + x: Input tensor. Must be 3D. + intensity: Transformation intensity. + channel_axis: Index of axis for channels in the input tensor. + + # Returns + Numpy image tensor. + + """ + x = np.rollaxis(x, channel_axis, 0) + min_x, max_x = np.min(x), np.max(x) + channel_images = [ + np.clip(x_channel + intensity, + min_x, + max_x) + for x_channel in x] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, channel_axis + 1) + return x + + +def random_channel_shift(x, intensity_range, channel_axis=0): + """Performs a random channel shift. + + # Arguments + x: Input tensor. Must be 3D. + intensity_range: Transformation intensity. + channel_axis: Index of axis for channels in the input tensor. + + # Returns + Numpy image tensor. + """ + intensity = np.random.uniform(-intensity_range, intensity_range) + return apply_channel_shift(x, intensity, channel_axis=channel_axis) + + +def apply_brightness_shift(x, brightness): + """Performs a brightness shift. + + # Arguments + x: Input tensor. Must be 3D. + brightness: Float. The new brightness value. + channel_axis: Index of axis for channels in the input tensor. + + # Returns + Numpy image tensor. + + # Raises + ValueError if `brightness_range` isn't a tuple. + """ + x = array_to_img(x) + x = imgenhancer_Brightness = ImageEnhance.Brightness(x) + x = imgenhancer_Brightness.enhance(brightness) + x = img_to_array(x) + return x + + +def random_brightness(x, brightness_range): + """Performs a random brightness shift. + + # Arguments + x: Input tensor. Must be 3D. + brightness_range: Tuple of floats; brightness range. + channel_axis: Index of axis for channels in the input tensor. + + # Returns + Numpy image tensor. + + # Raises + ValueError if `brightness_range` isn't a tuple. + """ + if len(brightness_range) != 2: + raise ValueError( + '`brightness_range should be tuple or list of two floats. ' + 'Received: %s' % brightness_range) + + u = np.random.uniform(brightness_range[0], brightness_range[1]) + return apply_brightness_shift(x, u) + + +def transform_matrix_offset_center(matrix, x, y): + o_x = float(x) / 2 + 0.5 + o_y = float(y) / 2 + 0.5 + offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) + reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) + transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) + return transform_matrix + + +def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1, + row_axis=0, col_axis=1, channel_axis=2, + fill_mode='nearest', cval=0.): + """Applies an affine transformation specified by the parameters given. + + # Arguments + x: 2D numpy array, single image. + theta: Rotation angle in degrees. + tx: Width shift. + ty: Heigh shift. + shear: Shear angle in degrees. + zx: Zoom in x direction. + zy: Zoom in y direction + row_axis: Index of axis for rows in the input image. + col_axis: Index of axis for columns in the input image. + channel_axis: Index of axis for channels in the input image. + fill_mode: Points outside the boundaries of the input + are filled according to the given mode + (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). + cval: Value used for points outside the boundaries + of the input if `mode='constant'`. + + # Returns + The transformed version of the input. + """ + transform_matrix = None + if theta != 0: + theta = np.deg2rad(theta) + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1]]) + transform_matrix = rotation_matrix + + if tx != 0 or ty != 0: + shift_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + if transform_matrix is None: + transform_matrix = shift_matrix + else: + transform_matrix = np.dot(transform_matrix, shift_matrix) + + if shear != 0: + shear = np.deg2rad(shear) + shear_matrix = np.array([[1, -np.sin(shear), 0], + [0, np.cos(shear), 0], + [0, 0, 1]]) + if transform_matrix is None: + transform_matrix = shear_matrix + else: + transform_matrix = np.dot(transform_matrix, shear_matrix) + + if zx != 1 or zy != 1: + zoom_matrix = np.array([[zx, 0, 0], + [0, zy, 0], + [0, 0, 1]]) + if transform_matrix is None: + transform_matrix = zoom_matrix + else: + transform_matrix = np.dot(transform_matrix, zoom_matrix) + + if transform_matrix is not None: + h, w = x.shape[row_axis], x.shape[col_axis] + transform_matrix = transform_matrix_offset_center( + transform_matrix, h, w) + x = np.rollaxis(x, channel_axis, 0) + final_affine_matrix = transform_matrix[:2, :2] + final_offset = transform_matrix[:2, 2] + + channel_images = [ndi.interpolation.affine_transform( + x_channel, + final_affine_matrix, + final_offset, + order=1, + mode=fill_mode, + cval=cval) for x_channel in x] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, channel_axis + 1) + return x + +def rgb2gray(rgb): + r,g,b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2] + gray = 0.2989* r + 0.5870*g + 0.1140*b + return gray + + +def flip_axis(x, axis): + x = np.asarray(x).swapaxes(axis, 0) + x = x[::-1, ...] + x = x.swapaxes(0, axis) + return x + + +def array_to_img(x, data_format=None, scale=True): + """Converts a 3D Numpy array to a PIL Image instance. + + # Arguments + x: Input Numpy array. + data_format: Image data format. + either "channels_first" or "channels_last". + scale: Whether to rescale image values + to be within `[0, 255]`. + + # Returns + A PIL Image instance. + + # Raises + ImportError: if PIL is not available. + ValueError: if invalid `x` or `data_format` is passed. + """ + if pil_image is None: + raise ImportError('Could not import PIL.Image. ' + 'The use of `array_to_img` requires PIL.') + x = np.asarray(x, dtype=backend.floatx()) + if x.ndim != 3: + raise ValueError('Expected image array to have rank 3 (single image). ' + 'Got array with shape:', x.shape) + + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('Invalid data_format:', data_format) + + # Original Numpy array x has format (height, width, channel) + # or (channel, height, width) + # but target PIL image has format (width, height, channel) + if data_format == 'channels_first': + x = x.transpose(1, 2, 0) + if scale: + x = x + max(-np.min(x), 0) + x_max = np.max(x) + if x_max != 0: + x /= x_max + x *= 255 + if x.shape[2] == 3: + # RGB + return pil_image.fromarray(x.astype('uint8'), 'RGB') + elif x.shape[2] == 1: + # grayscale + return pil_image.fromarray(x[:, :, 0].astype('uint8'), 'L') + else: + raise ValueError('Unsupported channel number: ', x.shape[2]) + + +def img_to_array(img, data_format=None): + """Converts a PIL Image instance to a Numpy array. + + # Arguments + img: PIL Image instance. + data_format: Image data format, + either "channels_first" or "channels_last". + + # Returns + A 3D Numpy array. + + # Raises + ValueError: if invalid `img` or `data_format` is passed. + """ + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('Unknown data_format: ', data_format) + # Numpy array x has format (height, width, channel) + # or (channel, height, width) + # but original PIL image has format (width, height, channel) + x = np.asarray(img, dtype=backend.floatx()) + if len(x.shape) == 3: + if data_format == 'channels_first': + x = x.transpose(2, 0, 1) + elif len(x.shape) == 2: + if data_format == 'channels_first': + x = x.reshape((1, x.shape[0], x.shape[1])) + else: + x = x.reshape((x.shape[0], x.shape[1], 1)) + else: + raise ValueError('Unsupported image shape: ', x.shape) + return x + + +def save_img(path, + x, + data_format=None, + file_format=None, + scale=True, **kwargs): + """Saves an image stored as a Numpy array to a path or file object. + + # Arguments + path: Path or file object. + x: Numpy array. + data_format: Image data format, + either "channels_first" or "channels_last". + file_format: Optional file format override. If omitted, the + format to use is determined from the filename extension. + If a file object was used instead of a filename, this + parameter should always be used. + scale: Whether to rescale image values to be within `[0, 255]`. + **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. + """ + img = array_to_img(x, data_format=data_format, scale=scale) + img.save(path, format=file_format, **kwargs) + + +def load_img(path, grayscale=False, target_size=None, + interpolation='nearest'): #nearest + """Loads an image into PIL format. + + # Arguments + path: Path to image file. + grayscale: Boolean, whether to load the image as grayscale. + target_size: Either `None` (default to original size) + or tuple of ints `(img_height, img_width)`. + interpolation: Interpolation method used to resample the image if the + target size is different from that of the loaded image. + Supported methods are "nearest", "bilinear", and "bicubic". + If PIL version 1.1.3 or newer is installed, "lanczos" is also + supported. If PIL version 3.4.0 or newer is installed, "box" and + "hamming" are also supported. By default, "nearest" is used. + + # Returns + A PIL Image instance. + + # Raises + ImportError: if PIL is not available. + ValueError: if interpolation method is not supported. + """ + if pil_image is None: + raise ImportError('Could not import PIL.Image. ' + 'The use of `array_to_img` requires PIL.') + img = pil_image.open(path) + if grayscale: + if img.mode != 'L': + img = img.convert('L') + else: + if img.mode != 'RGB': + img = img.convert('RGB') + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in _PIL_INTERPOLATION_METHODS: + raise ValueError( + 'Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format( + interpolation, + ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) + resample = _PIL_INTERPOLATION_METHODS[interpolation] + img = img.resize(width_height_tuple, resample) + return img + + +def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'): + return [os.path.join(root, f) + for root, _, files in os.walk(directory) for f in files + if re.match(r'([\w]+\.(?:' + ext + '))', f.lower())] + + +class ImageDataGenerator(object): + """Generate batches of tensor image data with real-time data augmentation. + The data will be looped over (in batches). + + # Arguments + featurewise_center: Boolean. + Set input mean to 0 over the dataset, feature-wise. + samplewise_center: Boolean. Set each sample mean to 0. + featurewise_std_normalization: Boolean. + Divide inputs by std of the dataset, feature-wise. + samplewise_std_normalization: Boolean. Divide each input by its std. + zca_epsilon: epsilon for ZCA whitening. Default is 1e-6. + zca_whitening: Boolean. Apply ZCA whitening. + rotation_range: Int. Degree range for random rotations. + width_shift_range: Float, 1-D array-like or int + - float: fraction of total width, if < 1, or pixels if >= 1. + - 1-D array-like: random elements from the array. + - int: integer number of pixels from interval + `(-width_shift_range, +width_shift_range)` + - With `width_shift_range=2` possible values + are integers `[-1, 0, +1]`, + same as with `width_shift_range=[-1, 0, +1]`, + while with `width_shift_range=1.0` possible values are floats + in the interval [-1.0, +1.0). + height_shift_range: Float, 1-D array-like or int + - float: fraction of total height, if < 1, or pixels if >= 1. + - 1-D array-like: random elements from the array. + - int: integer number of pixels from interval + `(-height_shift_range, +height_shift_range)` + - With `height_shift_range=2` possible values + are integers `[-1, 0, +1]`, + same as with `height_shift_range=[-1, 0, +1]`, + while with `height_shift_range=1.0` possible values are floats + in the interval [-1.0, +1.0). + shear_range: Float. Shear Intensity + (Shear angle in counter-clockwise direction in degrees) + zoom_range: Float or [lower, upper]. Range for random zoom. + If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`. + channel_shift_range: Float. Range for random channel shifts. + fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}. + Default is 'nearest'. + Points outside the boundaries of the input are filled + according to the given mode: + - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k) + - 'nearest': aaaaaaaa|abcd|dddddddd + - 'reflect': abcddcba|abcd|dcbaabcd + - 'wrap': abcdabcd|abcd|abcdabcd + cval: Float or Int. + Value used for points outside the boundaries + when `fill_mode = "constant"`. + horizontal_flip: Boolean. Randomly flip inputs horizontally. + vertical_flip: Boolean. Randomly flip inputs vertically. + rescale: rescaling factor. Defaults to None. + If None or 0, no rescaling is applied, + otherwise we multiply the data by the value provided + (before applying any other transformation). + preprocessing_function: function that will be implied on each input. + The function will run after the image is resized and augmented. + The function should take one argument: + one image (Numpy tensor with rank 3), + and should output a Numpy tensor with the same shape. + data_format: Image data format, + either "channels_first" or "channels_last". + "channels_last" mode means that the images should have shape + `(samples, height, width, channels)`, + "channels_first" mode means that the images should have shape + `(samples, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". + validation_split: Float. Fraction of images reserved for validation + (strictly between 0 and 1). + + # Examples + Example of using `.flow(x, y)`: + + ```python + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + y_train = np_utils.to_categorical(y_train, num_classes) + y_test = np_utils.to_categorical(y_test, num_classes) + + datagen = ImageDataGenerator( + featurewise_center=True, + featurewise_std_normalization=True, + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=True) + + # compute quantities required for featurewise normalization + # (std, mean, and principal components if ZCA whitening is applied) + datagen.fit(x_train) + + # fits the model on batches with real-time data augmentation: + model.fit_generator(datagen.flow(x_train, y_train, batch_size=32), + steps_per_epoch=len(x_train) / 32, epochs=epochs) + + # here's a more "manual" example + for e in range(epochs): + print('Epoch', e) + batches = 0 + for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32): + model.fit(x_batch, y_batch) + batches += 1 + if batches >= len(x_train) / 32: + # we need to break the loop by hand because + # the generator loops indefinitely + break + ``` + Example of using `.flow_from_directory(directory)`: + + ```python + train_datagen = ImageDataGenerator( + rescale=1./255, + shear_range=0.2, + zoom_range=0.2, + horizontal_flip=True) + + test_datagen = ImageDataGenerator(rescale=1./255) + + train_generator = train_datagen.flow_from_directory( + 'data/train', + target_size=(150, 150), + batch_size=32, + class_mode='binary') + + validation_generator = test_datagen.flow_from_directory( + 'data/validation', + target_size=(150, 150), + batch_size=32, + class_mode='binary') + + model.fit_generator( + train_generator, + steps_per_epoch=2000, + epochs=50, + validation_data=validation_generator, + validation_steps=800) + ``` + + Example of transforming images and masks together. + + ```python + # we create two instances with the same arguments + data_gen_args = dict(featurewise_center=True, + featurewise_std_normalization=True, + rotation_range=90., + width_shift_range=0.1, + height_shift_range=0.1, + zoom_range=0.2) + image_datagen = ImageDataGenerator(**data_gen_args) + mask_datagen = ImageDataGenerator(**data_gen_args) + + # Provide the same seed and keyword arguments to the fit and flow methods + seed = 1 + image_datagen.fit(images, augment=True, seed=seed) + mask_datagen.fit(masks, augment=True, seed=seed) + + image_generator = image_datagen.flow_from_directory( + 'data/images', + class_mode=None, + seed=seed) + + mask_generator = mask_datagen.flow_from_directory( + 'data/masks', + class_mode=None, + seed=seed) + + # combine generators into one which yields image and masks + train_generator = zip(image_generator, mask_generator) + + model.fit_generator( + train_generator, + steps_per_epoch=2000, + epochs=50) + ``` + """ + + def __init__(self, + contrast_stretching=False, + histogram_equalization=False, + adaptive_equalization=False, + featurewise_center=False, + samplewise_center=False, + featurewise_std_normalization=False, + samplewise_std_normalization=False, + zca_whitening=False, + zca_epsilon=1e-6, + rotation_range=0., + width_shift_range=0., + height_shift_range=0., + brightness_range=None, + shear_range=0., + zoom_range=0., + channel_shift_range=0., + fill_mode='nearest', + cval=0., + horizontal_flip=False, + vertical_flip=False, + rescale=None, + preprocessing_function=None, + data_format=None, + validation_split=0.0): + if data_format is None: + data_format = backend.image_data_format() + self.contrast_stretching = contrast_stretching + self.histogram_equalization = histogram_equalization + self.adaptive_equalization = adaptive_equalization + self.featurewise_center = featurewise_center + self.samplewise_center = samplewise_center + self.featurewise_std_normalization = featurewise_std_normalization + self.samplewise_std_normalization = samplewise_std_normalization + self.zca_whitening = zca_whitening + self.zca_epsilon = zca_epsilon + self.rotation_range = rotation_range + self.width_shift_range = width_shift_range + self.height_shift_range = height_shift_range + self.brightness_range = brightness_range + self.shear_range = shear_range + self.zoom_range = zoom_range + self.channel_shift_range = channel_shift_range + self.fill_mode = fill_mode + self.cval = cval + self.horizontal_flip = horizontal_flip + self.vertical_flip = vertical_flip + self.rescale = rescale + self.preprocessing_function = preprocessing_function + + if data_format not in {'channels_last', 'channels_first'}: + raise ValueError( + '`data_format` should be `"channels_last"` ' + '(channel after row and column) or ' + '`"channels_first"` (channel before row and column). ' + 'Received: %s' % data_format) + self.data_format = data_format + if data_format == 'channels_first': + self.channel_axis = 1 + self.row_axis = 2 + self.col_axis = 3 + if data_format == 'channels_last': + self.channel_axis = 3 + self.row_axis = 1 + self.col_axis = 2 + if validation_split and not 0 < validation_split < 1: + raise ValueError( + '`validation_split` must be strictly between 0 and 1. ' + ' Received: %s' % validation_split) + self._validation_split = validation_split + + self.mean = None + self.std = None + self.principal_components = None + + if np.isscalar(zoom_range): + self.zoom_range = [1 - zoom_range, 1 + zoom_range] + elif len(zoom_range) == 2: + self.zoom_range = [zoom_range[0], zoom_range[1]] + else: + raise ValueError('`zoom_range` should be a float or ' + 'a tuple or list of two floats. ' + 'Received: %s' % zoom_range) + if zca_whitening: + if not featurewise_center: + self.featurewise_center = True + warnings.warn('This ImageDataGenerator specifies ' + '`zca_whitening`, which overrides ' + 'setting of `featurewise_center`.') + if featurewise_std_normalization: + self.featurewise_std_normalization = False + warnings.warn('This ImageDataGenerator specifies ' + '`zca_whitening` ' + 'which overrides setting of' + '`featurewise_std_normalization`.') + if featurewise_std_normalization: + if not featurewise_center: + self.featurewise_center = True + warnings.warn('This ImageDataGenerator specifies ' + '`featurewise_std_normalization`, ' + 'which overrides setting of ' + '`featurewise_center`.') + if samplewise_std_normalization: + if not samplewise_center: + self.samplewise_center = True + warnings.warn('This ImageDataGenerator specifies ' + '`samplewise_std_normalization`, ' + 'which overrides setting of ' + '`samplewise_center`.') + + def flow(self, x, + y=None, batch_size=32, shuffle=True, + sample_weight=None, seed=None, + save_to_dir=None, save_prefix='', save_format='png', subset=None): + """Takes data & label arrays, generates batches of augmented data. + + # Arguments + x: Input data. Numpy array of rank 4 or a tuple. + If tuple, the first element + should contain the images and the second element + another numpy array or a list of numpy arrays + that gets passed to the output + without any modifications. + Can be used to feed the model miscellaneous data + along with the images. + In case of grayscale data, the channels axis of the image array + should have value 1, and in case + of RGB data, it should have value 3. + y: Labels. + batch_size: Int (default: 32). + shuffle: Boolean (default: True). + sample_weight: Sample weights. + seed: Int (default: None). + save_to_dir: None or str (default: None). + This allows you to optionally specify a directory + to which to save the augmented pictures being generated + (useful for visualizing what you are doing). + save_prefix: Str (default: `''`). + Prefix to use for filenames of saved pictures + (only relevant if `save_to_dir` is set). + save_format: one of "png", "jpeg" + (only relevant if `save_to_dir` is set). Default: "png". + subset: Subset of data (`"training"` or `"validation"`) if + `validation_split` is set in `ImageDataGenerator`. + + # Returns + An `Iterator` yielding tuples of `(x, y)` + where `x` is a numpy array of image data + (in the case of a single image input) or a list + of numpy arrays (in the case with + additional inputs) and `y` is a numpy array + of corresponding labels. If 'sample_weight' is not None, + the yielded tuples are of the form `(x, y, sample_weight)`. + If `y` is None, only the numpy array `x` is returned. + """ + return NumpyArrayIterator( + x, y, self, + batch_size=batch_size, + shuffle=shuffle, + sample_weight=sample_weight, + seed=seed, + data_format=self.data_format, + save_to_dir=save_to_dir, + save_prefix=save_prefix, + save_format=save_format, + subset=subset) + + def flow_from_directory(self, directory, + target_size=(256, 256), color_mode='rgb', + classes=None, class_mode='categorical', + batch_size=32, shuffle=True, seed=None, + save_to_dir=None, + save_prefix='', + save_format='png', + follow_links=False, + subset=None, + interpolation='nearest'): + """Takes the path to a directory & generates batches of augmented data. + + # Arguments + directory: Path to the target directory. + It should contain one subdirectory per class. + Any PNG, JPG, BMP, PPM or TIF images + inside each of the subdirectories directory tree + will be included in the generator. + See [this script]( + https://gist.github.com/fchollet/ + 0830affa1f7f19fd47b06d4cf89ed44d) + for more details. + target_size: Tuple of integers `(height, width)`, + default: `(256, 256)`. + The dimensions to which all images found will be resized. + color_mode: One of "grayscale", "rbg". Default: "rgb". + Whether the images will be converted to + have 1 or 3 color channels. + classes: Optional list of class subdirectories + (e.g. `['dogs', 'cats']`). Default: None. + If not provided, the list of classes will be automatically + inferred from the subdirectory names/structure + under `directory`, where each subdirectory will + be treated as a different class + (and the order of the classes, which will map to the label + indices, will be alphanumeric). + The dictionary containing the mapping from class names to class + indices can be obtained via the attribute `class_indices`. + class_mode: One of "categorical", "binary", "sparse", + "input", or None. Default: "categorical". + Determines the type of label arrays that are returned: + - "categorical" will be 2D one-hot encoded labels, + - "binary" will be 1D binary labels, + "sparse" will be 1D integer labels, + - "input" will be images identical + to input images (mainly used to work with autoencoders). + - If None, no labels are returned + (the generator will only yield batches of image data, + which is useful to use with `model.predict_generator()`, + `model.evaluate_generator()`, etc.). + Please note that in case of class_mode None, + the data still needs to reside in a subdirectory + of `directory` for it to work correctly. + batch_size: Size of the batches of data (default: 32). + shuffle: Whether to shuffle the data (default: True) + seed: Optional random seed for shuffling and transformations. + save_to_dir: None or str (default: None). + This allows you to optionally specify + a directory to which to save + the augmented pictures being generated + (useful for visualizing what you are doing). + save_prefix: Str. Prefix to use for filenames of saved pictures + (only relevant if `save_to_dir` is set). + save_format: One of "png", "jpeg" + (only relevant if `save_to_dir` is set). Default: "png". + follow_links: Whether to follow symlinks inside + class subdirectories (default: False). + subset: Subset of data (`"training"` or `"validation"`) if + `validation_split` is set in `ImageDataGenerator`. + interpolation: Interpolation method used to + resample the image if the + target size is different from that of the loaded image. + Supported methods are `"nearest"`, `"bilinear"`, + and `"bicubic"`. + If PIL version 1.1.3 or newer is installed, `"lanczos"` is also + supported. If PIL version 3.4.0 or newer is installed, + `"box"` and `"hamming"` are also supported. + By default, `"nearest"` is used. + + # Returns + A `DirectoryIterator` yielding tuples of `(x, y)` + where `x` is a numpy array containing a batch + of images with shape `(batch_size, *target_size, channels)` + and `y` is a numpy array of corresponding labels. + """ + return DirectoryIterator( + directory, self, + target_size=target_size, color_mode=color_mode, + classes=classes, class_mode=class_mode, + data_format=self.data_format, + batch_size=batch_size, shuffle=shuffle, seed=seed, + save_to_dir=save_to_dir, + save_prefix=save_prefix, + save_format=save_format, + follow_links=follow_links, + subset=subset, + interpolation=interpolation) + + def standardize(self, x): + """Applies the normalization configuration to a batch of inputs. + + # Arguments + x: Batch of inputs to be normalized. + + # Returns + The inputs, normalized. + """ + imagenet_mean = np.array([0.485, 0.456, 0.406]) + imagenet_std = np.array([0.229, 0.224, 0.225]) + + if self.rescale: + x *= self.rescale + if self.preprocessing_function: + x = self.preprocessing_function(x) +# if self.rescale: +# x *= self.rescale + if self.samplewise_center: + x -= np.mean(x, keepdims=True) + if self.samplewise_std_normalization: + x /= (np.std(x, keepdims=True) + backend.epsilon()) + + #x = (x - imagenet_mean) / imagenet_std + + if self.featurewise_center: + if self.mean is not None: + x -= self.mean + else: + warnings.warn('This ImageDataGenerator specifies ' + '`featurewise_center`, but it hasn\'t ' + 'been fit on any training data. Fit it ' + 'first by calling `.fit(numpy_data)`.') + if self.featurewise_std_normalization: + if self.std is not None: + x /= (self.std + backend.epsilon()) + else: + warnings.warn('This ImageDataGenerator specifies ' + '`featurewise_std_normalization`, ' + 'but it hasn\'t ' + 'been fit on any training data. Fit it ' + 'first by calling `.fit(numpy_data)`.') + if self.zca_whitening: + if self.principal_components is not None: + flatx = np.reshape(x, (-1, np.prod(x.shape[-3:]))) + whitex = np.dot(flatx, self.principal_components) + x = np.reshape(whitex, x.shape) + else: + warnings.warn('This ImageDataGenerator specifies ' + '`zca_whitening`, but it hasn\'t ' + 'been fit on any training data. Fit it ' + 'first by calling ') + + +# if self.contrast_stretching: +# if np.random.random() < 0.5: +# p2, p98 = np.percentile((x),(2,98)) +# x = (exposure.rescale_intensity((x), in_range=(p2, p98))) + + # if self.adaptive_equalization: + # if np.random.random() < 0.5: + # x = (exposure.equalize_adapthist((x), clip_limit = 0.03)) + + # if self.histogram_equalization: + # if np.random.random() < 0.5: + # x = (exposure.equalize_hist((x))) + + + return x + + + def get_random_transform(self, img_shape, seed=None): + """Generates random parameters for a transformation. + + # Arguments + seed: Random seed. + img_shape: Tuple of integers. + Shape of the image that is transformed. + + # Returns + A dictionary containing randomly chosen parameters describing the + transformation. + """ + img_row_axis = self.row_axis - 1 + img_col_axis = self.col_axis - 1 + + if seed is not None: + np.random.seed(seed) + + if self.rotation_range: + theta = np.random.uniform( + -self.rotation_range, + self.rotation_range) + else: + theta = 0 + + if self.height_shift_range: + try: # 1-D array-like or int + tx = np.random.choice(self.height_shift_range) + tx *= np.random.choice([-1, 1]) + except ValueError: # floating point + tx = np.random.uniform(-self.height_shift_range, + self.height_shift_range) + if np.max(self.height_shift_range) < 1: + tx *= img_shape[img_row_axis] + else: + tx = 0 + + if self.width_shift_range: + try: # 1-D array-like or int + ty = np.random.choice(self.width_shift_range) + ty *= np.random.choice([-1, 1]) + except ValueError: # floating point + ty = np.random.uniform(-self.width_shift_range, + self.width_shift_range) + if np.max(self.width_shift_range) < 1: + ty *= img_shape[img_col_axis] + else: + ty = 0 + + if self.shear_range: + shear = np.random.uniform( + -self.shear_range, + self.shear_range) + else: + shear = 0 + + if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: + zx, zy = 1, 1 + else: + zx, zy = np.random.uniform( + self.zoom_range[0], + self.zoom_range[1], + 2) + + flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip + flip_vertical = (np.random.random() < 0.5) * self.vertical_flip + + channel_shift_intensity = None + if self.channel_shift_range != 0: + channel_shift_intensity = np.random.uniform(-self.channel_shift_range, + self.channel_shift_range) + + brightness = None + if self.brightness_range is not None: + if len(self.brightness_range) != 2: + raise ValueError( + '`brightness_range should be tuple or list of two floats. ' + 'Received: %s' % brightness_range) + brightness = np.random.uniform(self.brightness_range[0], + self.brightness_range[1]) + + transform_parameters = {'theta': theta, + 'tx': tx, + 'ty': ty, + 'shear': shear, + 'zx': zx, + 'zy': zy, + 'flip_horizontal': flip_horizontal, + 'flip_vertical': flip_vertical, + 'channel_shift_intensity': channel_shift_intensity, + 'brightness': brightness, + 'contrast_stretching' : self.contrast_stretching, + 'adaptive_equalization' : self.adaptive_equalization, + 'histogram_equalization' : self.histogram_equalization + } + + return transform_parameters + + def apply_transform(self, x, transform_parameters): + """Applies a transformation to an image according to given parameters. + + # Arguments + x: 3D tensor, single image. + transform_parameters: Dictionary with string - parameter pairs + describing the transformation. + Currently, the following parameters + from the dictionary are used: + - `'theta'`: Float. Rotation angle in degrees. + - `'tx'`: Float. Shift in the x direction. + - `'ty'`: Float. Shift in the y direction. + - `'shear'`: Float. Shear angle in degrees. + - `'zx'`: Float. Zoom in the x direction. + - `'zy'`: Float. Zoom in the y direction. + - `'flip_horizontal'`: Boolean. Horizontal flip. + - `'flip_vertical'`: Boolean. Vertical flip. + - `'channel_shift_intencity'`: Float. Channel shift intensity. + - `'brightness'`: Float. Brightness shift intensity. + + # Returns + A ransformed version of the input (same shape). + """ + # x is a single image, so it doesn't have image number at index 0 + img_row_axis = self.row_axis - 1 + img_col_axis = self.col_axis - 1 + img_channel_axis = self.channel_axis - 1 + + x = apply_affine_transform(x, transform_parameters.get('theta', 0), + transform_parameters.get('tx', 0), + transform_parameters.get('ty', 0), + transform_parameters.get('shear', 0), + transform_parameters.get('zx', 1), + transform_parameters.get('zy', 1), + row_axis=img_row_axis, col_axis=img_col_axis, + channel_axis=img_channel_axis, + fill_mode=self.fill_mode, cval=self.cval) + + if transform_parameters.get('channel_shift_intensity') is not None: + x = apply_channel_shift(x, + transform_parameters['channel_shift_intensity'], + img_channel_axis) + + if transform_parameters.get('flip_horizontal', False): + x = flip_axis(x, img_col_axis) + + if transform_parameters.get('flip_vertical', False): + x = flip_axis(x, img_row_axis) + + if transform_parameters.get('brightness') is not None: + x = apply_brightness_shift(x, transform_parameters['brightness']) + + + + if transform_parameters.get('contrast_stretching') is not None: + if np.random.random() < 1.0: + x = img_to_array(x) + p2, p98 = np.percentile((x),(2,98)) + x = (exposure.rescale_intensity((x), in_range=(p2, p98))) + # x = x.reshape((x.shape[0], x.shape[1],3)) + +# if transform_parameters.get('adaptive_equalization') is not None: +# if np.random.random() < 1.0: +# x = (exposure.equalize_adapthist(x/255, clip_limit = 0.03)) +# x = x.reshape((x.shape[0], x.shape[1],1)) + + if transform_parameters.get('histogram_equalization') is not None: + if np.random.random() < 1.0: + x[:,:,0] = exposure.equalize_hist(x[:,:,0]) + x[:,:,1] = exposure.equalize_hist(x[:,:,1]) + x[:,:,2] = exposure.equalize_hist(x[:,:,2]) + +# x = x.reshape((x.shape[0], x.shape[1],3)) +# x = x.reshape((x.shape[0], x.shape[1], 1)) + + + return x + + def random_transform(self, x, seed=None): + """Applies a random transformation to an image. + + # Arguments + x: 3D tensor, single image. + seed: Random seed. + + # Returns + A randomly transformed version of the input (same shape). + """ + params = self.get_random_transform(x.shape, seed) + return self.apply_transform(x, params) + + def fit(self, x, + augment=False, + rounds=1, + seed=None): + """Fits the data generator to some sample data. + + This computes the internal data stats related to the + data-dependent transformations, based on an array of sample data. + + Only required if `featurewise_center` or + `featurewise_std_normalization` or `zca_whitening` are set to True. + + # Arguments + x: Sample data. Should have rank 4. + In case of grayscale data, + the channels axis should have value 1, and in case + of RGB data, it should have value 3. + augment: Boolean (default: False). + Whether to fit on randomly augmented samples. + rounds: Int (default: 1). + If using data augmentation (`augment=True`), + this is how many augmentation passes over the data to use. + seed: Int (default: None). Random seed. + """ + x = np.asarray(x, dtype=backend.floatx()) + if x.ndim != 4: + raise ValueError('Input to `.fit()` should have rank 4. ' + 'Got array with shape: ' + str(x.shape)) + if x.shape[self.channel_axis] not in {1, 3, 4}: + warnings.warn( + 'Expected input to be images (as Numpy array) ' + 'following the data format convention "' + + self.data_format + '" (channels on axis ' + + str(self.channel_axis) + '), i.e. expected ' + 'either 1, 3 or 4 channels on axis ' + + str(self.channel_axis) + '. ' + 'However, it was passed an array with shape ' + + str(x.shape) + ' (' + str(x.shape[self.channel_axis]) + + ' channels).') + + if seed is not None: + np.random.seed(seed) + + x = np.copy(x) + if augment: + ax = np.zeros( + tuple([rounds * x.shape[0]] + list(x.shape)[1:]), + dtype=backend.floatx()) + for r in range(rounds): + for i in range(x.shape[0]): + ax[i + r * x.shape[0]] = self.random_transform(x[i]) + x = ax + + if self.featurewise_center: + self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis)) + broadcast_shape = [1, 1, 1] + broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] + self.mean = np.reshape(self.mean, broadcast_shape) + x -= self.mean + + if self.featurewise_std_normalization: + self.std = np.std(x, axis=(0, self.row_axis, self.col_axis)) + broadcast_shape = [1, 1, 1] + broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] + self.std = np.reshape(self.std, broadcast_shape) + x /= (self.std + backend.epsilon()) + + if self.zca_whitening: + flat_x = np.reshape( + x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) + sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] + u, s, _ = linalg.svd(sigma) + s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon) + self.principal_components = (u * s_inv).dot(u.T) + + +class Iterator(keras_utils.Sequence): + """Base class for image data iterators. + + Every `Iterator` must implement the `_get_batches_of_transformed_samples` + method. + + # Arguments + n: Integer, total number of samples in the dataset to loop over. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seeding for data shuffling. + """ + + def __init__(self, n, batch_size, shuffle, seed): + self.n = n + self.batch_size = batch_size + self.seed = seed + self.shuffle = shuffle + self.batch_index = 0 + self.total_batches_seen = 0 + self.lock = threading.Lock() + self.index_array = None + self.index_generator = self._flow_index() + + def _set_index_array(self): + self.index_array = np.arange(self.n) + if self.shuffle: + self.index_array = np.random.permutation(self.n) + + def __getitem__(self, idx): + if idx >= len(self): + raise ValueError('Asked to retrieve element {idx}, ' + 'but the Sequence ' + 'has length {length}'.format(idx=idx, + length=len(self))) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + self.total_batches_seen += 1 + if self.index_array is None: + self._set_index_array() + index_array = self.index_array[self.batch_size * idx: + self.batch_size * (idx + 1)] + return self._get_batches_of_transformed_samples(index_array) + + def __len__(self): + return (self.n + self.batch_size - 1) // self.batch_size # round up + + def on_epoch_end(self): + self._set_index_array() + + def reset(self): + self.batch_index = 0 + + def _flow_index(self): + # Ensure self.batch_index is 0. + self.reset() + while 1: + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + if self.batch_index == 0: + self._set_index_array() + + current_index = (self.batch_index * self.batch_size) % self.n + if self.n > current_index + self.batch_size: + self.batch_index += 1 + else: + self.batch_index = 0 + self.total_batches_seen += 1 + yield self.index_array[current_index: + current_index + self.batch_size] + + def __iter__(self): + # Needed if we want to do something like: + # for x, y in data_gen.flow(...): + return self + + def __next__(self, *args, **kwargs): + return self.next(*args, **kwargs) + + def _get_batches_of_transformed_samples(self, index_array): + """Gets a batch of transformed samples. + + # Arguments + index_array: Array of sample indices to include in batch. + + # Returns + A batch of transformed samples. + """ + raise NotImplementedError + + +class NumpyArrayIterator(Iterator): + """Iterator yielding data from a Numpy array. + + # Arguments + x: Numpy array of input data or tuple. + If tuple, the second elements is either + another numpy array or a list of numpy arrays, + each of which gets passed + through as an output without any modifications. + y: Numpy array of targets data. + image_data_generator: Instance of `ImageDataGenerator` + to use for random transformations and normalization. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + sample_weight: Numpy array of sample weights. + seed: Random seed for data shuffling. + data_format: String, one of `channels_first`, `channels_last`. + save_to_dir: Optional directory where to save the pictures + being yielded, in a viewable format. This is useful + for visualizing the random transformations being + applied, for debugging purposes. + save_prefix: String prefix to use for saving sample + images (if `save_to_dir` is set). + save_format: Format to use for saving sample images + (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. + """ + + def __init__(self, x, y, image_data_generator, + batch_size=32, shuffle=False, sample_weight=None, + seed=None, data_format=None, + save_to_dir=None, save_prefix='', save_format='png', + subset=None): + if (type(x) is tuple) or (type(x) is list): + if type(x[1]) is not list: + x_misc = [np.asarray(x[1])] + else: + x_misc = [np.asarray(xx) for xx in x[1]] + x = x[0] + for xx in x_misc: + if len(x) != len(xx): + raise ValueError( + 'All of the arrays in `x` ' + 'should have the same length. ' + 'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' % + (len(x), len(xx))) + else: + x_misc = [] + + if y is not None and len(x) != len(y): + raise ValueError('`x` (images tensor) and `y` (labels) ' + 'should have the same length. ' + 'Found: x.shape = %s, y.shape = %s' % + (np.asarray(x).shape, np.asarray(y).shape)) + if sample_weight is not None and len(x) != len(sample_weight): + raise ValueError('`x` (images tensor) and `sample_weight` ' + 'should have the same length. ' + 'Found: x.shape = %s, sample_weight.shape = %s' % + (np.asarray(x).shape, np.asarray(sample_weight).shape)) + if subset is not None: + if subset not in {'training', 'validation'}: + raise ValueError('Invalid subset name:', subset, + '; expected "training" or "validation".') + split_idx = int(len(x) * image_data_generator._validation_split) + if subset == 'validation': + x = x[:split_idx] + x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc] + if y is not None: + y = y[:split_idx] + else: + x = x[split_idx:] + x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] + if y is not None: + y = y[split_idx:] + if data_format is None: + data_format = backend.image_data_format() + self.x = np.asarray(x, dtype=backend.floatx()) + self.x_misc = x_misc + if self.x.ndim != 4: + raise ValueError('Input data in `NumpyArrayIterator` ' + 'should have rank 4. You passed an array ' + 'with shape', self.x.shape) + channels_axis = 3 if data_format == 'channels_last' else 1 + if self.x.shape[channels_axis] not in {1, 3, 4}: + warnings.warn('NumpyArrayIterator is set to use the ' + 'data format convention "' + data_format + '" ' + '(channels on axis ' + str(channels_axis) + + '), i.e. expected either 1, 3 or 4 ' + 'channels on axis ' + str(channels_axis) + '. ' + 'However, it was passed an array with shape ' + + str(self.x.shape) + ' (' + + str(self.x.shape[channels_axis]) + ' channels).') + if y is not None: + self.y = np.asarray(y) + else: + self.y = None + if sample_weight is not None: + self.sample_weight = np.asarray(sample_weight) + else: + self.sample_weight = None + self.image_data_generator = image_data_generator + self.data_format = data_format + self.save_to_dir = save_to_dir + self.save_prefix = save_prefix + self.save_format = save_format + super(NumpyArrayIterator, self).__init__(x.shape[0], + batch_size, + shuffle, + seed) + + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), + dtype=backend.floatx()) + for i, j in enumerate(index_array): + x = self.x[j] + params = self.image_data_generator.get_random_transform(x.shape) + x = self.image_data_generator.apply_transform( + x.astype(backend.floatx()), params) + x = self.image_data_generator.standardize(x) + batch_x[i] = x + + if self.save_to_dir: + for i, j in enumerate(index_array): + img = array_to_img(batch_x[i], self.data_format, scale=True) + fname = '{prefix}_{index}_{hash}.{format}'.format( + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e4), + format=self.save_format) + img.save(os.path.join(self.save_to_dir, fname)) + batch_x_miscs = [xx[index_array] for xx in self.x_misc] + output = (batch_x if batch_x_miscs == [] + else [batch_x] + batch_x_miscs,) + if self.y is None: + return output[0] + output += (self.y[index_array],) + if self.sample_weight is not None: + output += (self.sample_weight[index_array],) + return output + + def next(self): + """For python 2.x. + + # Returns + The next batch. + """ + # Keeps under lock only the mechanism which advances + # the indexing of each batch. + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array) + + +def _iter_valid_files(directory, white_list_formats, follow_links): + """Iterates on files with extension in `white_list_formats` contained in `directory`. + + # Arguments + directory: Absolute path to the directory + containing files to be counted + white_list_formats: Set of strings containing allowed extensions for + the files to be counted. + follow_links: Boolean. + + # Yields + Tuple of (root, filename) with extension in `white_list_formats`. + """ + def _recursive_list(subpath): + return sorted(os.walk(subpath, followlinks=follow_links), + key=lambda x: x[0]) + + for root, _, files in _recursive_list(directory): + for fname in sorted(files): + for extension in white_list_formats: + if fname.lower().endswith('.tiff'): + warnings.warn('Using \'.tiff\' files with multiple bands ' + 'will cause distortion. ' + 'Please verify your output.') + if fname.lower().endswith('.' + extension): + yield root, fname + + +def _count_valid_files_in_directory(directory, + white_list_formats, + split, + follow_links): + """Counts files with extension in `white_list_formats` contained in `directory`. + + # Arguments + directory: absolute path to the directory + containing files to be counted + white_list_formats: set of strings containing allowed extensions for + the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. + follow_links: boolean. + + # Returns + the count of files with extension in `white_list_formats` contained in + the directory. + """ + num_files = len(list( + _iter_valid_files(directory, white_list_formats, follow_links))) + if split: + start, stop = int(split[0] * num_files), int(split[1] * num_files) + else: + start, stop = 0, num_files + return stop - start + + +def _list_valid_filenames_in_directory(directory, white_list_formats, split, + class_indices, follow_links): + """Lists paths of files in `subdir` with extensions in `white_list_formats`. + + # Arguments + directory: absolute path to a directory containing the files to list. + The directory name is used as class label + and must be a key of `class_indices`. + white_list_formats: set of strings containing allowed extensions for + the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. + class_indices: dictionary mapping a class name to its index. + follow_links: boolean. + + # Returns + classes: a list of class indices + filenames: the path of valid files in `directory`, relative from + `directory`'s parent (e.g., if `directory` is "dataset/class1", + the filenames will be + `["class1/file1.jpg", "class1/file2.jpg", ...]`). + """ + dirname = os.path.basename(directory) + if split: + num_files = len(list( + _iter_valid_files(directory, white_list_formats, follow_links))) + start, stop = int(split[0] * num_files), int(split[1] * num_files) + valid_files = list( + _iter_valid_files( + directory, white_list_formats, follow_links))[start: stop] + else: + valid_files = _iter_valid_files( + directory, white_list_formats, follow_links) + + classes = [] + filenames = [] + for root, fname in valid_files: + classes.append(class_indices[dirname]) + absolute_path = os.path.join(root, fname) + relative_path = os.path.join( + dirname, os.path.relpath(absolute_path, directory)) + filenames.append(relative_path) + + return classes, filenames + + +class DirectoryIterator(Iterator): + """Iterator capable of reading images from a directory on disk. + + # Arguments + directory: Path to the directory to read images from. + Each subdirectory in this directory will be + considered to contain images from one class, + or alternatively you could specify class subdirectories + via the `classes` argument. + image_data_generator: Instance of `ImageDataGenerator` + to use for random transformations and normalization. + target_size: tuple of integers, dimensions to resize input images to. + color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images. + classes: Optional list of strings, names of subdirectories + containing images from each class (e.g. `["dogs", "cats"]`). + It will be computed automatically if not set. + class_mode: Mode for yielding the targets: + `"binary"`: binary targets (if there are only two classes), + `"categorical"`: categorical targets, + `"sparse"`: integer targets, + `"input"`: targets are images identical to input images (mainly + used to work with autoencoders), + `None`: no targets get yielded (only input images are yielded). + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seed for data shuffling. + data_format: String, one of `channels_first`, `channels_last`. + save_to_dir: Optional directory where to save the pictures + being yielded, in a viewable format. This is useful + for visualizing the random transformations being + applied, for debugging purposes. + save_prefix: String prefix to use for saving sample + images (if `save_to_dir` is set). + save_format: Format to use for saving sample images + (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. + interpolation: Interpolation method used to resample the image if the + target size is different from that of the loaded image. + Supported methods are "nearest", "bilinear", and "bicubic". + If PIL version 1.1.3 or newer is installed, "lanczos" is also + supported. If PIL version 3.4.0 or newer is installed, "box" and + "hamming" are also supported. By default, "nearest" is used. + """ + + def __init__(self, directory, image_data_generator, + target_size=(256, 256), color_mode='rgb', + classes=None, class_mode='categorical', + batch_size=32, shuffle=True, seed=None, + data_format=None, + save_to_dir=None, save_prefix='', save_format='png', + follow_links=False, + subset=None, + interpolation='nearest'): + if data_format is None: + data_format = backend.image_data_format() + self.directory = directory + self.image_data_generator = image_data_generator + self.target_size = tuple(target_size) + if color_mode not in {'rgb', 'grayscale'}: + raise ValueError('Invalid color mode:', color_mode, + '; expected "rgb" or "grayscale".') + self.color_mode = color_mode + self.data_format = data_format + if self.color_mode == 'rgb': + if self.data_format == 'channels_last': + self.image_shape = self.target_size + (3,) + else: + self.image_shape = (3,) + self.target_size + else: + if self.data_format == 'channels_last': + self.image_shape = self.target_size + (1,) + else: + self.image_shape = (1,) + self.target_size + self.classes = classes + if class_mode not in {'categorical', 'binary', 'sparse', + 'input', None}: + raise ValueError('Invalid class_mode:', class_mode, + '; expected one of "categorical", ' + '"binary", "sparse", "input"' + ' or None.') + self.class_mode = class_mode + self.save_to_dir = save_to_dir + self.save_prefix = save_prefix + self.save_format = save_format + self.interpolation = interpolation + + if subset is not None: + validation_split = self.image_data_generator._validation_split + if subset == 'validation': + split = (0, validation_split) + elif subset == 'training': + split = (validation_split, 1) + else: + raise ValueError('Invalid subset name: ', subset, + '; expected "training" or "validation"') + else: + split = None + self.subset = subset + + white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', + 'ppm', 'tif', 'tiff'} + # First, count the number of samples and classes. + self.samples = 0 + + if not classes: + classes = [] + for subdir in sorted(os.listdir(directory)): + if os.path.isdir(os.path.join(directory, subdir)): + classes.append(subdir) + self.num_classes = len(classes) + self.class_indices = dict(zip(classes, range(len(classes)))) + + pool = multiprocessing.pool.ThreadPool() + function_partial = partial(_count_valid_files_in_directory, + white_list_formats=white_list_formats, + follow_links=follow_links, + split=split) + self.samples = sum(pool.map(function_partial, + (os.path.join(directory, subdir) + for subdir in classes))) + + print('Found %d images belonging to %d classes.' % + (self.samples, self.num_classes)) + + # Second, build an index of the images + # in the different class subfolders. + results = [] + self.filenames = [] + self.classes = np.zeros((self.samples,), dtype='int32') + i = 0 + for dirpath in (os.path.join(directory, subdir) for subdir in classes): + results.append( + pool.apply_async(_list_valid_filenames_in_directory, + (dirpath, white_list_formats, split, + self.class_indices, follow_links))) + for res in results: + classes, filenames = res.get() + self.classes[i:i + len(classes)] = classes + self.filenames += filenames + i += len(classes) + + pool.close() + pool.join() + super(DirectoryIterator, self).__init__(self.samples, + batch_size, + shuffle, + seed) + + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros( + (len(index_array),) + self.image_shape, + dtype=backend.floatx()) + grayscale = self.color_mode == 'grayscale' + # build batch of image data + for i, j in enumerate(index_array): + fname = self.filenames[j] + img = load_img(os.path.join(self.directory, fname), + grayscale=grayscale, + target_size=self.target_size, + interpolation=self.interpolation) + x = img_to_array(img, data_format=self.data_format) + params = self.image_data_generator.get_random_transform(x.shape) + x = self.image_data_generator.apply_transform(x, params) + + x = self.image_data_generator.standardize(x) + + batch_x[i] = x + # optionally save augmented images to disk for debugging purposes + if self.save_to_dir: + for i, j in enumerate(index_array): + img = array_to_img(batch_x[i], self.data_format, scale=True) + fname = '{prefix}_{index}_{hash}.{format}'.format( + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e7), + format=self.save_format) + img.save(os.path.join(self.save_to_dir, fname)) + # build batch of labels + if self.class_mode == 'input': + batch_y = batch_x.copy() + elif self.class_mode == 'sparse': + batch_y = self.classes[index_array] + elif self.class_mode == 'binary': + batch_y = self.classes[index_array].astype(backend.floatx()) + elif self.class_mode == 'categorical': + batch_y = np.zeros( + (len(batch_x), self.num_classes), + dtype=backend.floatx()) + for i, label in enumerate(self.classes[index_array]): + batch_y[i, label] = 1. + else: + return batch_x + return batch_x, batch_y + + def next(self): + """For python 2.x. + + # Returns + The next batch. + """ + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array)