|
# b/image.py
"""Utilities for real-time data augmentation on image data.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import re
from scipy import linalg
import scipy.ndimage as ndi
from six.moves import range
import os
import threading
import warnings
import multiprocessing.pool
import cv2
from functools import partial
from skimage import data, img_as_float
from skimage import exposure

from . import get_keras_submodule

backend = get_keras_submodule('backend')
keras_utils = get_keras_submodule('utils')

try:
    from PIL import ImageEnhance
    from PIL import Image as pil_image
except ImportError:
    pil_image = None

if pil_image is not None:
    _PIL_INTERPOLATION_METHODS = {
        'nearest': pil_image.NEAREST,
        'bilinear': pil_image.BILINEAR,
        'bicubic': pil_image.BICUBIC,
        'antialias': pil_image.ANTIALIAS,
    }
    # These methods were only introduced in version 3.4.0 (2016).
    if hasattr(pil_image, 'HAMMING'):
        _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING
    if hasattr(pil_image, 'BOX'):
        _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX
    # This method is new in version 1.1.3 (2013).
    if hasattr(pil_image, 'LANCZOS'):
        _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS


def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
                    fill_mode='nearest', cval=0.):
    """Performs a random rotation of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        rg: Rotation range, in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Rotated Numpy image tensor.
    """
    theta = np.random.uniform(-rg, rg)
    x = apply_affine_transform(x, theta=theta, channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x


def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0,
                 fill_mode='nearest', cval=0.):
    """Performs a random spatial shift of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        wrg: Width shift range, as a float fraction of the width.
        hrg: Height shift range, as a float fraction of the height.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Shifted Numpy image tensor.
    """
    h, w = x.shape[row_axis], x.shape[col_axis]
    tx = np.random.uniform(-hrg, hrg) * h
    ty = np.random.uniform(-wrg, wrg) * w
    x = apply_affine_transform(x, tx=tx, ty=ty, channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x


def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0,
                 fill_mode='nearest', cval=0.):
    """Performs a random spatial shear of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity in degrees.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Sheared Numpy image tensor.
    """
    shear = np.random.uniform(-intensity, intensity)
    x = apply_affine_transform(x, shear=shear, channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x


def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0,
                fill_mode='nearest', cval=0.):
    """Performs a random spatial zoom of a Numpy image tensor.

    # Arguments
        x: Input tensor. Must be 3D.
        zoom_range: Tuple of floats; zoom range for width and height.
        row_axis: Index of axis for rows in the input tensor.
        col_axis: Index of axis for columns in the input tensor.
        channel_axis: Index of axis for channels in the input tensor.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        Zoomed Numpy image tensor.

    # Raises
        ValueError: if `zoom_range` isn't a tuple.
    """
    if len(zoom_range) != 2:
        raise ValueError('`zoom_range` should be a tuple or list of two'
                         ' floats. Received: ', zoom_range)

    if zoom_range[0] == 1 and zoom_range[1] == 1:
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
    x = apply_affine_transform(x, zx=zx, zy=zy, channel_axis=channel_axis,
                               fill_mode=fill_mode, cval=cval)
    return x


def apply_channel_shift(x, intensity, channel_axis=0):
    """Performs a channel shift.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    # Returns
        Numpy image tensor.
    """
    x = np.rollaxis(x, channel_axis, 0)
    min_x, max_x = np.min(x), np.max(x)
    channel_images = [
        np.clip(x_channel + intensity,
                min_x,
                max_x)
        for x_channel in x]
    x = np.stack(channel_images, axis=0)
    x = np.rollaxis(x, 0, channel_axis + 1)
    return x


def random_channel_shift(x, intensity_range, channel_axis=0):
    """Performs a random channel shift.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity_range: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    # Returns
        Numpy image tensor.
    """
    intensity = np.random.uniform(-intensity_range, intensity_range)
    return apply_channel_shift(x, intensity, channel_axis=channel_axis)


def apply_brightness_shift(x, brightness):
    """Performs a brightness shift.

    # Arguments
        x: Input tensor. Must be 3D.
        brightness: Float. The new brightness value.

    # Returns
        Numpy image tensor.
    """
    x = array_to_img(x)
    x = ImageEnhance.Brightness(x).enhance(brightness)
    x = img_to_array(x)
    return x
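
# Illustrative note (not part of the API): `PIL.ImageEnhance.Brightness`
# treats the factor as a multiplier, so `brightness=1.0` returns the image
# unchanged, values below 1.0 darken it, and values above 1.0 brighten it.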
|
|
def random_brightness(x, brightness_range):
    """Performs a random brightness shift.

    # Arguments
        x: Input tensor. Must be 3D.
        brightness_range: Tuple of floats; brightness range.

    # Returns
        Numpy image tensor.

    # Raises
        ValueError if `brightness_range` isn't a tuple.
    """
    if len(brightness_range) != 2:
        raise ValueError(
            '`brightness_range` should be tuple or list of two floats. '
            'Received: %s' % (brightness_range,))

    u = np.random.uniform(brightness_range[0], brightness_range[1])
    return apply_brightness_shift(x, u)


def transform_matrix_offset_center(matrix, x, y):
    o_x = float(x) / 2 + 0.5
    o_y = float(y) / 2 + 0.5
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
    return transform_matrix
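
# Note: `transform_matrix_offset_center` conjugates a homogeneous 3x3
# transform with translations to and from the image center, i.e. it returns
# offset . matrix . reset, so rotation, shear and zoom act about the center
# of the image rather than its top-left corner.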
|
|
def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
                           row_axis=0, col_axis=1, channel_axis=2,
                           fill_mode='nearest', cval=0.):
    """Applies an affine transformation specified by the parameters given.

    # Arguments
        x: 3D numpy array, single image.
        theta: Rotation angle in degrees.
        tx: Width shift.
        ty: Height shift.
        shear: Shear angle in degrees.
        zx: Zoom in x direction.
        zy: Zoom in y direction.
        row_axis: Index of axis for rows in the input image.
        col_axis: Index of axis for columns in the input image.
        channel_axis: Index of axis for channels in the input image.
        fill_mode: Points outside the boundaries of the input
            are filled according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries
            of the input if `mode='constant'`.

    # Returns
        The transformed version of the input.
    """
    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx],
                                 [0, 1, ty],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, h, w)
        x = np.rollaxis(x, channel_axis, 0)
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        channel_images = [ndi.interpolation.affine_transform(
            x_channel,
            final_affine_matrix,
            final_offset,
            order=1,
            mode=fill_mode,
            cval=cval) for x_channel in x]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x
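
# A minimal usage sketch (the names `img` and `out` are illustrative):
#     out = apply_affine_transform(img, theta=15, zx=1.1, zy=1.1,
#                                  row_axis=0, col_axis=1, channel_axis=2)
# applies a 15-degree rotation and zoom factors of 1.1 along each axis to a
# channels-last image, centered on the image via the offset matrix above.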
|
|
def rgb2gray(rgb):
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray
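
# The weights above are the ITU-R BT.601 luma coefficients commonly used for
# RGB-to-grayscale conversion.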
|
|
def flip_axis(x, axis):
    x = np.asarray(x).swapaxes(axis, 0)
    x = x[::-1, ...]
    x = x.swapaxes(0, axis)
    return x
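
# `flip_axis(x, axis)` is equivalent to `np.flip(x, axis)` on NumPy >= 1.12;
# the explicit swapaxes version also works on older NumPy releases.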
|
|
def array_to_img(x, data_format=None, scale=True):
    """Converts a 3D Numpy array to a PIL Image instance.

    # Arguments
        x: Input Numpy array.
        data_format: Image data format,
            either "channels_first" or "channels_last".
        scale: Whether to rescale image values
            to be within `[0, 255]`.

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
        ValueError: if invalid `x` or `data_format` is passed.
    """
    if pil_image is None:
        raise ImportError('Could not import PIL.Image. '
                          'The use of `array_to_img` requires PIL.')
    x = np.asarray(x, dtype=backend.floatx())
    if x.ndim != 3:
        raise ValueError('Expected image array to have rank 3 (single image). '
                         'Got array with shape:', x.shape)

    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Invalid data_format:', data_format)

    # Original Numpy array x has format (height, width, channel)
    # or (channel, height, width)
    # but target PIL image has format (width, height, channel)
    if data_format == 'channels_first':
        x = x.transpose(1, 2, 0)
    if scale:
        x = x + max(-np.min(x), 0)
        x_max = np.max(x)
        if x_max != 0:
            x /= x_max
        x *= 255
    if x.shape[2] == 3:
        # RGB
        return pil_image.fromarray(x.astype('uint8'), 'RGB')
    elif x.shape[2] == 1:
        # grayscale
        return pil_image.fromarray(x[:, :, 0].astype('uint8'), 'L')
    else:
        raise ValueError('Unsupported channel number: ', x.shape[2])


def img_to_array(img, data_format=None):
    """Converts a PIL Image instance to a Numpy array.

    # Arguments
        img: PIL Image instance.
        data_format: Image data format,
            either "channels_first" or "channels_last".

    # Returns
        A 3D Numpy array.

    # Raises
        ValueError: if invalid `img` or `data_format` is passed.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format: ', data_format)
    # Numpy array x has format (height, width, channel)
    # or (channel, height, width)
    # but original PIL image has format (width, height, channel)
    x = np.asarray(img, dtype=backend.floatx())
    if len(x.shape) == 3:
        if data_format == 'channels_first':
            x = x.transpose(2, 0, 1)
    elif len(x.shape) == 2:
        if data_format == 'channels_first':
            x = x.reshape((1, x.shape[0], x.shape[1]))
        else:
            x = x.reshape((x.shape[0], x.shape[1], 1))
    else:
        raise ValueError('Unsupported image shape: ', x.shape)
    return x
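
# Round-trip sketch (illustrative): `img_to_array(array_to_img(x))` returns
# an array with the same shape as `x`; with `scale=True` the values are
# rescaled into `[0, 255]` along the way.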
|
|
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True, **kwargs):
    """Saves an image stored as a Numpy array to a path or file object.

    # Arguments
        path: Path or file object.
        x: Numpy array.
        data_format: Image data format,
            either "channels_first" or "channels_last".
        file_format: Optional file format override. If omitted, the
            format to use is determined from the filename extension.
            If a file object was used instead of a filename, this
            parameter should always be used.
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
    """
    img = array_to_img(x, data_format=data_format, scale=scale)
    img.save(path, format=file_format, **kwargs)


def load_img(path, grayscale=False, target_size=None,
             interpolation='nearest'):
    """Loads an image into PIL format.

    # Arguments
        path: Path to image file.
        grayscale: Boolean, whether to load the image as grayscale.
        target_size: Either `None` (default to original size)
            or tuple of ints `(img_height, img_width)`.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
        ValueError: if interpolation method is not supported.
    """
    if pil_image is None:
        raise ImportError('Could not import PIL.Image. '
                          'The use of `load_img` requires PIL.')
    img = pil_image.open(path)
    if grayscale:
        if img.mode != 'L':
            img = img.convert('L')
    else:
        if img.mode != 'RGB':
            img = img.convert('RGB')
    if target_size is not None:
        width_height_tuple = (target_size[1], target_size[0])
        if img.size != width_height_tuple:
            if interpolation not in _PIL_INTERPOLATION_METHODS:
                raise ValueError(
                    'Invalid interpolation method {} specified. Supported '
                    'methods are {}'.format(
                        interpolation,
                        ", ".join(_PIL_INTERPOLATION_METHODS.keys())))
            resample = _PIL_INTERPOLATION_METHODS[interpolation]
            img = img.resize(width_height_tuple, resample)
    return img
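
# Typical usage (the path here is illustrative):
#     img = load_img('data/example.png', target_size=(224, 224))
#     x = img_to_array(img)  # float array of shape (224, 224, 3)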
|
|
def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
    return [os.path.join(root, f)
            for root, _, files in os.walk(directory) for f in files
            if re.match(r'([\w]+\.(?:' + ext + '))', f.lower())]


class ImageDataGenerator(object):
    """Generate batches of tensor image data with real-time data augmentation.
    The data will be looped over (in batches).

    # Arguments
        contrast_stretching: Boolean. Rescale image intensity to its
            2nd-98th percentile range (via `skimage.exposure`).
        histogram_equalization: Boolean. Apply per-channel histogram
            equalization (via `skimage.exposure`).
        adaptive_equalization: Boolean. Apply adaptive histogram
            equalization (CLAHE; note: currently disabled in
            `apply_transform`).
        featurewise_center: Boolean.
            Set input mean to 0 over the dataset, feature-wise.
        samplewise_center: Boolean. Set each sample mean to 0.
        featurewise_std_normalization: Boolean.
            Divide inputs by std of the dataset, feature-wise.
        samplewise_std_normalization: Boolean. Divide each input by its std.
        zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
        zca_whitening: Boolean. Apply ZCA whitening.
        rotation_range: Int. Degree range for random rotations.
        width_shift_range: Float, 1-D array-like or int
            - float: fraction of total width, if < 1, or pixels if >= 1.
            - 1-D array-like: random elements from the array.
            - int: integer number of pixels from interval
                `(-width_shift_range, +width_shift_range)`
            - With `width_shift_range=2` possible values
                are integers `[-1, 0, +1]`,
                same as with `width_shift_range=[-1, 0, +1]`,
                while with `width_shift_range=1.0` possible values are floats
                in the interval [-1.0, +1.0).
        height_shift_range: Float, 1-D array-like or int
            - float: fraction of total height, if < 1, or pixels if >= 1.
            - 1-D array-like: random elements from the array.
            - int: integer number of pixels from interval
                `(-height_shift_range, +height_shift_range)`
            - With `height_shift_range=2` possible values
                are integers `[-1, 0, +1]`,
                same as with `height_shift_range=[-1, 0, +1]`,
                while with `height_shift_range=1.0` possible values are floats
                in the interval [-1.0, +1.0).
        brightness_range: Tuple or list of two floats. Range for picking
            a brightness shift value from.
        shear_range: Float. Shear intensity
            (shear angle in counter-clockwise direction in degrees).
        zoom_range: Float or [lower, upper]. Range for random zoom.
            If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
        channel_shift_range: Float. Range for random channel shifts.
        fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
            Default is 'nearest'.
            Points outside the boundaries of the input are filled
            according to the given mode:
            - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
            - 'nearest': aaaaaaaa|abcd|dddddddd
            - 'reflect': abcddcba|abcd|dcbaabcd
            - 'wrap': abcdabcd|abcd|abcdabcd
        cval: Float or Int.
            Value used for points outside the boundaries
            when `fill_mode = "constant"`.
        horizontal_flip: Boolean. Randomly flip inputs horizontally.
        vertical_flip: Boolean. Randomly flip inputs vertically.
        rescale: rescaling factor. Defaults to None.
            If None or 0, no rescaling is applied,
            otherwise we multiply the data by the value provided
            (before applying any other transformation).
        preprocessing_function: function that will be applied to each input.
            The function will run after the image is resized and augmented.
            The function should take one argument:
            one image (Numpy tensor with rank 3),
            and should output a Numpy tensor with the same shape.
        data_format: Image data format,
            either "channels_first" or "channels_last".
            "channels_last" mode means that the images should have shape
            `(samples, height, width, channels)`,
            "channels_first" mode means that the images should have shape
            `(samples, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        validation_split: Float. Fraction of images reserved for validation
            (strictly between 0 and 1).

    # Examples
    Example of using `.flow(x, y)`:

    ```python
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    y_train = np_utils.to_categorical(y_train, num_classes)
    y_test = np_utils.to_categorical(y_test, num_classes)

    datagen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)

    # compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied)
    datagen.fit(x_train)

    # fits the model on batches with real-time data augmentation:
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                        steps_per_epoch=len(x_train) / 32, epochs=epochs)

    # here's a more "manual" example
    for e in range(epochs):
        print('Epoch', e)
        batches = 0
        for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
            model.fit(x_batch, y_batch)
            batches += 1
            if batches >= len(x_train) / 32:
                # we need to break the loop by hand because
                # the generator loops indefinitely
                break
    ```
    Example of using `.flow_from_directory(directory)`:

    ```python
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

    validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

    model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)
    ```

    Example of transforming images and masks together.

    ```python
    # we create two instances with the same arguments
    data_gen_args = dict(featurewise_center=True,
                         featurewise_std_normalization=True,
                         rotation_range=90.,
                         width_shift_range=0.1,
                         height_shift_range=0.1,
                         zoom_range=0.2)
    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    # Provide the same seed and keyword arguments to the fit and flow methods
    seed = 1
    image_datagen.fit(images, augment=True, seed=seed)
    mask_datagen.fit(masks, augment=True, seed=seed)

    image_generator = image_datagen.flow_from_directory(
        'data/images',
        class_mode=None,
        seed=seed)

    mask_generator = mask_datagen.flow_from_directory(
        'data/masks',
        class_mode=None,
        seed=seed)

    # combine generators into one which yields image and masks
    train_generator = zip(image_generator, mask_generator)

    model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50)
    ```
    """

    def __init__(self,
                 contrast_stretching=False,
                 histogram_equalization=False,
                 adaptive_equalization=False,
                 featurewise_center=False,
                 samplewise_center=False,
                 featurewise_std_normalization=False,
                 samplewise_std_normalization=False,
                 zca_whitening=False,
                 zca_epsilon=1e-6,
                 rotation_range=0.,
                 width_shift_range=0.,
                 height_shift_range=0.,
                 brightness_range=None,
                 shear_range=0.,
                 zoom_range=0.,
                 channel_shift_range=0.,
                 fill_mode='nearest',
                 cval=0.,
                 horizontal_flip=False,
                 vertical_flip=False,
                 rescale=None,
                 preprocessing_function=None,
                 data_format=None,
                 validation_split=0.0):
        if data_format is None:
            data_format = backend.image_data_format()
        self.contrast_stretching = contrast_stretching
        self.histogram_equalization = histogram_equalization
        self.adaptive_equalization = adaptive_equalization
        self.featurewise_center = featurewise_center
        self.samplewise_center = samplewise_center
        self.featurewise_std_normalization = featurewise_std_normalization
        self.samplewise_std_normalization = samplewise_std_normalization
        self.zca_whitening = zca_whitening
        self.zca_epsilon = zca_epsilon
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.brightness_range = brightness_range
        self.shear_range = shear_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rescale = rescale
        self.preprocessing_function = preprocessing_function

        if data_format not in {'channels_last', 'channels_first'}:
            raise ValueError(
                '`data_format` should be `"channels_last"` '
                '(channel after row and column) or '
                '`"channels_first"` (channel before row and column). '
                'Received: %s' % data_format)
        self.data_format = data_format
        if data_format == 'channels_first':
            self.channel_axis = 1
            self.row_axis = 2
            self.col_axis = 3
        if data_format == 'channels_last':
            self.channel_axis = 3
            self.row_axis = 1
            self.col_axis = 2
        if validation_split and not 0 < validation_split < 1:
            raise ValueError(
                '`validation_split` must be strictly between 0 and 1. '
                'Received: %s' % validation_split)
        self._validation_split = validation_split

        self.mean = None
        self.std = None
        self.principal_components = None

        if np.isscalar(zoom_range):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif len(zoom_range) == 2:
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError('`zoom_range` should be a float or '
                             'a tuple or list of two floats. '
                             'Received: %s' % (zoom_range,))
        if zca_whitening:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening`, which overrides '
                              'setting of `featurewise_center`.')
            if featurewise_std_normalization:
                self.featurewise_std_normalization = False
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening` '
                              'which overrides setting of '
                              '`featurewise_std_normalization`.')
        if featurewise_std_normalization:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_std_normalization`, '
                              'which overrides setting of '
                              '`featurewise_center`.')
        if samplewise_std_normalization:
            if not samplewise_center:
                self.samplewise_center = True
                warnings.warn('This ImageDataGenerator specifies '
                              '`samplewise_std_normalization`, '
                              'which overrides setting of '
                              '`samplewise_center`.')

    def flow(self, x,
             y=None, batch_size=32, shuffle=True,
             sample_weight=None, seed=None,
             save_to_dir=None, save_prefix='', save_format='png', subset=None):
        """Takes data & label arrays, generates batches of augmented data.

        # Arguments
            x: Input data. Numpy array of rank 4 or a tuple.
                If tuple, the first element
                should contain the images and the second element
                another numpy array or a list of numpy arrays
                that gets passed to the output
                without any modifications.
                Can be used to feed the model miscellaneous data
                along with the images.
                In case of grayscale data, the channels axis of the image array
                should have value 1, and in case
                of RGB data, it should have value 3.
            y: Labels.
            batch_size: Int (default: 32).
            shuffle: Boolean (default: True).
            sample_weight: Sample weights.
            seed: Int (default: None).
            save_to_dir: None or str (default: None).
                This allows you to optionally specify a directory
                to which to save the augmented pictures being generated
                (useful for visualizing what you are doing).
            save_prefix: Str (default: `''`).
                Prefix to use for filenames of saved pictures
                (only relevant if `save_to_dir` is set).
            save_format: one of "png", "jpeg"
                (only relevant if `save_to_dir` is set). Default: "png".
            subset: Subset of data (`"training"` or `"validation"`) if
                `validation_split` is set in `ImageDataGenerator`.

        # Returns
            An `Iterator` yielding tuples of `(x, y)`
            where `x` is a numpy array of image data
            (in the case of a single image input) or a list
            of numpy arrays (in the case with
            additional inputs) and `y` is a numpy array
            of corresponding labels. If `sample_weight` is not None,
            the yielded tuples are of the form `(x, y, sample_weight)`.
            If `y` is None, only the numpy array `x` is returned.
        """
        return NumpyArrayIterator(
            x, y, self,
            batch_size=batch_size,
            shuffle=shuffle,
            sample_weight=sample_weight,
            seed=seed,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            subset=subset)

    def flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None,
                            save_prefix='',
                            save_format='png',
                            follow_links=False,
                            subset=None,
                            interpolation='nearest'):
        """Takes the path to a directory & generates batches of augmented data.

        # Arguments
            directory: Path to the target directory.
                It should contain one subdirectory per class.
                Any PNG, JPG, BMP, PPM or TIF images
                inside each of the subdirectories directory tree
                will be included in the generator.
                See [this script](
                https://gist.github.com/fchollet/
                0830affa1f7f19fd47b06d4cf89ed44d)
                for more details.
            target_size: Tuple of integers `(height, width)`,
                default: `(256, 256)`.
                The dimensions to which all images found will be resized.
            color_mode: One of "grayscale", "rgb". Default: "rgb".
                Whether the images will be converted to
                have 1 or 3 color channels.
            classes: Optional list of class subdirectories
                (e.g. `['dogs', 'cats']`). Default: None.
                If not provided, the list of classes will be automatically
                inferred from the subdirectory names/structure
                under `directory`, where each subdirectory will
                be treated as a different class
                (and the order of the classes, which will map to the label
                indices, will be alphanumeric).
                The dictionary containing the mapping from class names to class
                indices can be obtained via the attribute `class_indices`.
            class_mode: One of "categorical", "binary", "sparse",
                "input", or None. Default: "categorical".
                Determines the type of label arrays that are returned:
                - "categorical" will be 2D one-hot encoded labels,
                - "binary" will be 1D binary labels,
                - "sparse" will be 1D integer labels,
                - "input" will be images identical
                    to input images (mainly used to work with autoencoders).
                - If None, no labels are returned
                    (the generator will only yield batches of image data,
                    which is useful to use with `model.predict_generator()`,
                    `model.evaluate_generator()`, etc.).
                Please note that in case of class_mode None,
                the data still needs to reside in a subdirectory
                of `directory` for it to work correctly.
            batch_size: Size of the batches of data (default: 32).
            shuffle: Whether to shuffle the data (default: True).
            seed: Optional random seed for shuffling and transformations.
            save_to_dir: None or str (default: None).
                This allows you to optionally specify
                a directory to which to save
                the augmented pictures being generated
                (useful for visualizing what you are doing).
            save_prefix: Str. Prefix to use for filenames of saved pictures
                (only relevant if `save_to_dir` is set).
            save_format: One of "png", "jpeg"
                (only relevant if `save_to_dir` is set). Default: "png".
            follow_links: Whether to follow symlinks inside
                class subdirectories (default: False).
            subset: Subset of data (`"training"` or `"validation"`) if
                `validation_split` is set in `ImageDataGenerator`.
            interpolation: Interpolation method used to
                resample the image if the
                target size is different from that of the loaded image.
                Supported methods are `"nearest"`, `"bilinear"`,
                and `"bicubic"`.
                If PIL version 1.1.3 or newer is installed, `"lanczos"` is also
                supported. If PIL version 3.4.0 or newer is installed,
                `"box"` and `"hamming"` are also supported.
                By default, `"nearest"` is used.

        # Returns
            A `DirectoryIterator` yielding tuples of `(x, y)`
            where `x` is a numpy array containing a batch
            of images with shape `(batch_size, *target_size, channels)`
            and `y` is a numpy array of corresponding labels.
        """
        return DirectoryIterator(
            directory, self,
            target_size=target_size, color_mode=color_mode,
            classes=classes, class_mode=class_mode,
            data_format=self.data_format,
            batch_size=batch_size, shuffle=shuffle, seed=seed,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            follow_links=follow_links,
            subset=subset,
            interpolation=interpolation)

    def standardize(self, x):
        """Applies the normalization configuration to a batch of inputs.

        # Arguments
            x: Batch of inputs to be normalized.

        # Returns
            The inputs, normalized.
        """
        if self.rescale:
            x *= self.rescale
        if self.preprocessing_function:
            x = self.preprocessing_function(x)
        if self.samplewise_center:
            x -= np.mean(x, keepdims=True)
        if self.samplewise_std_normalization:
            x /= (np.std(x, keepdims=True) + backend.epsilon())

        # Optional ImageNet-style normalization, currently disabled:
        # imagenet_mean = np.array([0.485, 0.456, 0.406])
        # imagenet_std = np.array([0.229, 0.224, 0.225])
        # x = (x - imagenet_mean) / imagenet_std

        if self.featurewise_center:
            if self.mean is not None:
                x -= self.mean
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_center`, but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')
        if self.featurewise_std_normalization:
            if self.std is not None:
                x /= (self.std + backend.epsilon())
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`featurewise_std_normalization`, '
                              'but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')
        if self.zca_whitening:
            if self.principal_components is not None:
                flatx = np.reshape(x, (-1, np.prod(x.shape[-3:])))
                whitex = np.dot(flatx, self.principal_components)
                x = np.reshape(whitex, x.shape)
            else:
                warnings.warn('This ImageDataGenerator specifies '
                              '`zca_whitening`, but it hasn\'t '
                              'been fit on any training data. Fit it '
                              'first by calling `.fit(numpy_data)`.')

        # Contrast stretching, adaptive equalization and histogram
        # equalization are applied per image in `apply_transform`.

        return x

    def get_random_transform(self, img_shape, seed=None):
        """Generates random parameters for a transformation.

        # Arguments
            img_shape: Tuple of integers.
                Shape of the image that is transformed.
            seed: Random seed.

        # Returns
            A dictionary containing randomly chosen parameters describing the
            transformation.
        """
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1

        if seed is not None:
            np.random.seed(seed)

        if self.rotation_range:
            theta = np.random.uniform(
                -self.rotation_range,
                self.rotation_range)
        else:
            theta = 0

        if self.height_shift_range:
            try:  # 1-D array-like or int
                tx = np.random.choice(self.height_shift_range)
                tx *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                tx = np.random.uniform(-self.height_shift_range,
                                       self.height_shift_range)
            if np.max(self.height_shift_range) < 1:
                tx *= img_shape[img_row_axis]
        else:
            tx = 0

        if self.width_shift_range:
            try:  # 1-D array-like or int
                ty = np.random.choice(self.width_shift_range)
                ty *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                ty = np.random.uniform(-self.width_shift_range,
                                       self.width_shift_range)
            if np.max(self.width_shift_range) < 1:
                ty *= img_shape[img_col_axis]
        else:
            ty = 0

        if self.shear_range:
            shear = np.random.uniform(
                -self.shear_range,
                self.shear_range)
        else:
            shear = 0

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(
                self.zoom_range[0],
                self.zoom_range[1],
                2)

        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        channel_shift_intensity = None
        if self.channel_shift_range != 0:
            channel_shift_intensity = np.random.uniform(
                -self.channel_shift_range, self.channel_shift_range)

        brightness = None
        if self.brightness_range is not None:
            if len(self.brightness_range) != 2:
                raise ValueError(
                    '`brightness_range` should be tuple or list of '
                    'two floats. Received: %s' % (self.brightness_range,))
            brightness = np.random.uniform(self.brightness_range[0],
                                           self.brightness_range[1])

        transform_parameters = {'theta': theta,
                                'tx': tx,
                                'ty': ty,
                                'shear': shear,
                                'zx': zx,
                                'zy': zy,
                                'flip_horizontal': flip_horizontal,
                                'flip_vertical': flip_vertical,
                                'channel_shift_intensity': channel_shift_intensity,
                                'brightness': brightness,
                                'contrast_stretching': self.contrast_stretching,
                                'adaptive_equalization': self.adaptive_equalization,
                                'histogram_equalization': self.histogram_equalization}

        return transform_parameters

    def apply_transform(self, x, transform_parameters):
        """Applies a transformation to an image according to given parameters.

        # Arguments
            x: 3D tensor, single image.
            transform_parameters: Dictionary with string - parameter pairs
                describing the transformation.
                Currently, the following parameters
                from the dictionary are used:
                - `'theta'`: Float. Rotation angle in degrees.
                - `'tx'`: Float. Shift in the x direction.
                - `'ty'`: Float. Shift in the y direction.
                - `'shear'`: Float. Shear angle in degrees.
                - `'zx'`: Float. Zoom in the x direction.
                - `'zy'`: Float. Zoom in the y direction.
                - `'flip_horizontal'`: Boolean. Horizontal flip.
                - `'flip_vertical'`: Boolean. Vertical flip.
                - `'channel_shift_intensity'`: Float. Channel shift intensity.
                - `'brightness'`: Float. Brightness shift intensity.

        # Returns
            A transformed version of the input (same shape).
        """
        # x is a single image, so it doesn't have image number at index 0
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1
        img_channel_axis = self.channel_axis - 1

        x = apply_affine_transform(x, transform_parameters.get('theta', 0),
                                   transform_parameters.get('tx', 0),
                                   transform_parameters.get('ty', 0),
                                   transform_parameters.get('shear', 0),
                                   transform_parameters.get('zx', 1),
                                   transform_parameters.get('zy', 1),
                                   row_axis=img_row_axis, col_axis=img_col_axis,
                                   channel_axis=img_channel_axis,
                                   fill_mode=self.fill_mode, cval=self.cval)

        if transform_parameters.get('channel_shift_intensity') is not None:
            x = apply_channel_shift(x,
                                    transform_parameters['channel_shift_intensity'],
                                    img_channel_axis)

        if transform_parameters.get('flip_horizontal', False):
            x = flip_axis(x, img_col_axis)

        if transform_parameters.get('flip_vertical', False):
            x = flip_axis(x, img_row_axis)

        if transform_parameters.get('brightness') is not None:
            x = apply_brightness_shift(x, transform_parameters['brightness'])

        # Note: these flags must be tested for truthiness, not `is not None`;
        # they are always present in the dictionary (False by default).
        if transform_parameters.get('contrast_stretching'):
            x = img_to_array(x)
            p2, p98 = np.percentile(x, (2, 98))
            x = exposure.rescale_intensity(x, in_range=(p2, p98))

        # Adaptive (CLAHE) equalization is currently disabled:
        # if transform_parameters.get('adaptive_equalization'):
        #     x = exposure.equalize_adapthist(x / 255, clip_limit=0.03)

        if transform_parameters.get('histogram_equalization'):
            x[:, :, 0] = exposure.equalize_hist(x[:, :, 0])
            x[:, :, 1] = exposure.equalize_hist(x[:, :, 1])
            x[:, :, 2] = exposure.equalize_hist(x[:, :, 2])

        return x

    def random_transform(self, x, seed=None):
        """Applies a random transformation to an image.

        # Arguments
            x: 3D tensor, single image.
            seed: Random seed.

        # Returns
            A randomly transformed version of the input (same shape).
        """
        params = self.get_random_transform(x.shape, seed)
        return self.apply_transform(x, params)
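
    # Sketch: to apply identical random transforms to paired inputs (e.g. an
    # image and its segmentation mask), draw the parameters once and reuse:
    #     params = datagen.get_random_transform(image.shape, seed=seed)
    #     image = datagen.apply_transform(image, params)
    #     mask = datagen.apply_transform(mask, params)
    # (`datagen`, `image` and `mask` are illustrative names.)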
|
|

    def fit(self, x,
            augment=False,
            rounds=1,
            seed=None):
        """Fits the data generator to some sample data.

        This computes the internal data stats related to the
        data-dependent transformations, based on an array of sample data.

        Only required if `featurewise_center` or
        `featurewise_std_normalization` or `zca_whitening` are set to True.

        # Arguments
            x: Sample data. Should have rank 4.
                In case of grayscale data,
                the channels axis should have value 1, and in case
                of RGB data, it should have value 3.
            augment: Boolean (default: False).
                Whether to fit on randomly augmented samples.
            rounds: Int (default: 1).
                If using data augmentation (`augment=True`),
                this is how many augmentation passes over the data to use.
            seed: Int (default: None). Random seed.
        """
        x = np.asarray(x, dtype=backend.floatx())
        if x.ndim != 4:
            raise ValueError('Input to `.fit()` should have rank 4. '
                             'Got array with shape: ' + str(x.shape))
        if x.shape[self.channel_axis] not in {1, 3, 4}:
            warnings.warn(
                'Expected input to be images (as Numpy array) '
                'following the data format convention "' +
                self.data_format + '" (channels on axis ' +
                str(self.channel_axis) + '), i.e. expected '
                'either 1, 3 or 4 channels on axis ' +
                str(self.channel_axis) + '. '
                'However, it was passed an array with shape ' +
                str(x.shape) + ' (' + str(x.shape[self.channel_axis]) +
                ' channels).')

        if seed is not None:
            np.random.seed(seed)

        x = np.copy(x)
        if augment:
            ax = np.zeros(
                tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
                dtype=backend.floatx())
            for r in range(rounds):
                for i in range(x.shape[0]):
                    ax[i + r * x.shape[0]] = self.random_transform(x[i])
            x = ax

        if self.featurewise_center:
            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean

        if self.featurewise_std_normalization:
            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.std = np.reshape(self.std, broadcast_shape)
            x /= (self.std + backend.epsilon())

        if self.zca_whitening:
            flat_x = np.reshape(
                x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
            u, s, _ = linalg.svd(sigma)
            s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon)
            self.principal_components = (u * s_inv).dot(u.T)
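            # The whitening matrix is U diag(1 / sqrt(s + eps)) U^T, where U
            # and s come from the SVD of the covariance of the flattened
            # samples; `standardize` can then decorrelate inputs with a
            # single matrix product.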
|
|
1286 |
|
|
|
1287 |
|
|
|
1288 |
class Iterator(keras_utils.Sequence): |
|
|
1289 |
"""Base class for image data iterators. |
|
|
1290 |
|
|
|
1291 |
Every `Iterator` must implement the `_get_batches_of_transformed_samples` |
|
|
1292 |
method. |
|
|
1293 |
|
|
|
1294 |
# Arguments |
|
|
1295 |
n: Integer, total number of samples in the dataset to loop over. |
|
|
1296 |
batch_size: Integer, size of a batch. |
|
|
1297 |
shuffle: Boolean, whether to shuffle the data between epochs. |
|
|
1298 |
seed: Random seeding for data shuffling. |
|
|
1299 |
""" |
|
|
1300 |
|
|
|
1301 |
def __init__(self, n, batch_size, shuffle, seed): |
|
|
1302 |
self.n = n |
|
|
1303 |
self.batch_size = batch_size |
|
|
1304 |
self.seed = seed |
|
|
1305 |
self.shuffle = shuffle |
|
|
1306 |
self.batch_index = 0 |
|
|
1307 |
self.total_batches_seen = 0 |
|
|
1308 |
self.lock = threading.Lock() |
|
|
1309 |
self.index_array = None |
|
|
1310 |
self.index_generator = self._flow_index() |
|
|
1311 |
|
|
|
1312 |
def _set_index_array(self): |
|
|
1313 |
self.index_array = np.arange(self.n) |
|
|
1314 |
if self.shuffle: |
|
|
1315 |
self.index_array = np.random.permutation(self.n) |
|
|
1316 |
|
|
|
1317 |
def __getitem__(self, idx): |
|
|
1318 |
if idx >= len(self): |
|
|
1319 |
raise ValueError('Asked to retrieve element {idx}, ' |
|
|
1320 |
'but the Sequence ' |
|
|
1321 |
'has length {length}'.format(idx=idx, |
|
|
1322 |
length=len(self))) |
|
|
1323 |
if self.seed is not None: |
|
|
1324 |
np.random.seed(self.seed + self.total_batches_seen) |
|
|
1325 |
self.total_batches_seen += 1 |
|
|
1326 |
if self.index_array is None: |
|
|
1327 |
self._set_index_array() |
|
|
1328 |
index_array = self.index_array[self.batch_size * idx: |
|
|
1329 |
self.batch_size * (idx + 1)] |
|
|
1330 |
return self._get_batches_of_transformed_samples(index_array) |
|
|
1331 |
|
|
|
1332 |
def __len__(self): |
|
|
1333 |
return (self.n + self.batch_size - 1) // self.batch_size # round up |
|
|
1334 |
|
|
|
1335 |
def on_epoch_end(self): |
|
|
1336 |
self._set_index_array() |
|
|
1337 |
|
|
|
1338 |
def reset(self): |
|
|
1339 |
self.batch_index = 0 |
|
|
1340 |
|
|
|
1341 |
def _flow_index(self): |
|
|
1342 |
# Ensure self.batch_index is 0. |
|
|
1343 |
self.reset() |
|
|
1344 |
while 1: |
|
|
1345 |
if self.seed is not None: |
|
|
1346 |
np.random.seed(self.seed + self.total_batches_seen) |
|
|
1347 |
if self.batch_index == 0: |
|
|
1348 |
self._set_index_array() |
|
|
1349 |
|
|
|
1350 |
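            # Wrap the running index around the dataset so the generator
            # can loop over the data indefinitely.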
            current_index = (self.batch_index * self.batch_size) % self.n
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[current_index:
                                   current_index + self.batch_size]

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        # Arguments
            index_array: Array of sample indices to include in batch.

        # Returns
            A batch of transformed samples.
        """
        raise NotImplementedError


class NumpyArrayIterator(Iterator):
    """Iterator yielding data from a Numpy array.

    # Arguments
        x: Numpy array of input data or tuple.
            If tuple, the second element is either
            another Numpy array or a list of Numpy arrays,
            each of which is passed through
            to the output without modification.
        y: Numpy array of target data.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        sample_weight: Numpy array of sample weights.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
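
    # Example
    A minimal sketch (array shapes and generator settings are illustrative):

    ```python
    x = np.random.random((100, 32, 32, 3))
    y = np.random.randint(0, 2, size=(100,))
    datagen = ImageDataGenerator(horizontal_flip=True)
    it = NumpyArrayIterator(x, y, datagen, batch_size=10, shuffle=True)
    batch_x, batch_y = it[0]  # arrays of shape (10, 32, 32, 3) and (10,)
    ```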
""" |
|
|
1407 |
|
|
|
1408 |
def __init__(self, x, y, image_data_generator, |
|
|
1409 |
batch_size=32, shuffle=False, sample_weight=None, |
|
|
1410 |
seed=None, data_format=None, |
|
|
1411 |
save_to_dir=None, save_prefix='', save_format='png', |
|
|
1412 |
subset=None): |
|
|
1413 |
if (type(x) is tuple) or (type(x) is list): |
|
|
1414 |
if type(x[1]) is not list: |
|
|
1415 |
x_misc = [np.asarray(x[1])] |
|
|
1416 |
else: |
|
|
1417 |
x_misc = [np.asarray(xx) for xx in x[1]] |
|
|
1418 |
x = x[0] |
|
|
1419 |
for xx in x_misc: |
|
|
1420 |
if len(x) != len(xx): |
|
|
1421 |
raise ValueError( |
|
|
1422 |
'All of the arrays in `x` ' |
|
|
1423 |
'should have the same length. ' |
|
|
1424 |
'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' % |
|
|
1425 |
(len(x), len(xx))) |
|
|
1426 |
else: |
|
|
1427 |
x_misc = [] |
|
|
1428 |
|
|
|
1429 |
if y is not None and len(x) != len(y): |
|
|
1430 |
raise ValueError('`x` (images tensor) and `y` (labels) ' |
|
|
1431 |
'should have the same length. ' |
|
|
1432 |
'Found: x.shape = %s, y.shape = %s' % |
|
|
1433 |
(np.asarray(x).shape, np.asarray(y).shape)) |
|
|
1434 |
if sample_weight is not None and len(x) != len(sample_weight): |
|
|
1435 |
raise ValueError('`x` (images tensor) and `sample_weight` ' |
|
|
1436 |
'should have the same length. ' |
|
|
1437 |
'Found: x.shape = %s, sample_weight.shape = %s' % |
|
|
1438 |
(np.asarray(x).shape, np.asarray(sample_weight).shape)) |
|
|
1439 |
if subset is not None: |
|
|
1440 |
if subset not in {'training', 'validation'}: |
|
|
1441 |
raise ValueError('Invalid subset name:', subset, |
|
|
1442 |
'; expected "training" or "validation".') |
|
|
1443 |
split_idx = int(len(x) * image_data_generator._validation_split) |
|
|
1444 |
if subset == 'validation': |
|
|
1445 |
x = x[:split_idx] |
|
|
1446 |
x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc] |
|
|
1447 |
if y is not None: |
|
|
1448 |
y = y[:split_idx] |
|
|
1449 |
else: |
|
|
1450 |
x = x[split_idx:] |
|
|
1451 |
x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] |
|
|
1452 |
if y is not None: |
|
|
1453 |
y = y[split_idx:] |
|
|
1454 |
if data_format is None: |
|
|
1455 |
data_format = backend.image_data_format() |
|
|
1456 |
self.x = np.asarray(x, dtype=backend.floatx()) |
|
|
1457 |
self.x_misc = x_misc |
|
|
1458 |
if self.x.ndim != 4: |
|
|
1459 |
raise ValueError('Input data in `NumpyArrayIterator` ' |
|
|
1460 |
'should have rank 4. You passed an array ' |
|
|
1461 |
'with shape', self.x.shape) |
|
|
1462 |
channels_axis = 3 if data_format == 'channels_last' else 1 |
|
|
1463 |
if self.x.shape[channels_axis] not in {1, 3, 4}: |
|
|
1464 |
warnings.warn('NumpyArrayIterator is set to use the ' |
|
|
1465 |
'data format convention "' + data_format + '" ' |
|
|
1466 |
'(channels on axis ' + str(channels_axis) + |
|
|
1467 |
'), i.e. expected either 1, 3 or 4 ' |
|
|
1468 |
'channels on axis ' + str(channels_axis) + '. ' |
|
|
1469 |
'However, it was passed an array with shape ' + |
|
|
1470 |
str(self.x.shape) + ' (' + |
|
|
1471 |
str(self.x.shape[channels_axis]) + ' channels).') |
|
|
1472 |
if y is not None: |
|
|
1473 |
self.y = np.asarray(y) |
|
|
1474 |
else: |
|
|
1475 |
self.y = None |
|
|
1476 |
if sample_weight is not None: |
|
|
1477 |
self.sample_weight = np.asarray(sample_weight) |
|
|
1478 |
else: |
|
|
1479 |
self.sample_weight = None |
|
|
1480 |
self.image_data_generator = image_data_generator |
|
|
1481 |
self.data_format = data_format |
|
|
1482 |
self.save_to_dir = save_to_dir |
|
|
1483 |
self.save_prefix = save_prefix |
|
|
1484 |
self.save_format = save_format |
|
|
1485 |
super(NumpyArrayIterator, self).__init__(x.shape[0], |
|
|
1486 |
batch_size, |
|
|
1487 |
shuffle, |
|
|
1488 |
seed) |
|
|
1489 |
|
|
|
1490 |
def _get_batches_of_transformed_samples(self, index_array): |
|
|
1491 |
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), |
|
|
1492 |
dtype=backend.floatx()) |
|
|
1493 |
for i, j in enumerate(index_array): |
|
|
1494 |
x = self.x[j] |
|
|
1495 |
params = self.image_data_generator.get_random_transform(x.shape) |
|
|
1496 |
x = self.image_data_generator.apply_transform( |
|
|
1497 |
x.astype(backend.floatx()), params) |
|
|
1498 |
x = self.image_data_generator.standardize(x) |
|
|
1499 |
batch_x[i] = x |
|
|
1500 |
|
|
|
1501 |
if self.save_to_dir: |
|
|
1502 |
for i, j in enumerate(index_array): |
|
|
1503 |
img = array_to_img(batch_x[i], self.data_format, scale=True) |
|
|
1504 |
fname = '{prefix}_{index}_{hash}.{format}'.format( |
|
|
1505 |
prefix=self.save_prefix, |
|
|
1506 |
index=j, |
|
|
1507 |
hash=np.random.randint(1e4), |
|
|
1508 |
format=self.save_format) |
|
|
1509 |
img.save(os.path.join(self.save_to_dir, fname)) |
|
|
1510 |
batch_x_miscs = [xx[index_array] for xx in self.x_misc] |
|
|
1511 |
output = (batch_x if batch_x_miscs == [] |
|
|
1512 |
else [batch_x] + batch_x_miscs,) |
|
|
1513 |
if self.y is None: |
|
|
1514 |
return output[0] |
|
|
1515 |
output += (self.y[index_array],) |
|
|
1516 |
if self.sample_weight is not None: |
|
|
1517 |
output += (self.sample_weight[index_array],) |
|
|
1518 |
return output |
|
|
1519 |
|
|
|
1520 |
def next(self): |
|
|
1521 |
"""For python 2.x. |
|
|
1522 |
|
|
|
1523 |
# Returns |
|
|
1524 |
The next batch. |
|
|
1525 |
""" |
|
|
1526 |
# Keeps under lock only the mechanism which advances |
|
|
1527 |
# the indexing of each batch. |
|
|
1528 |
with self.lock: |
|
|
1529 |
index_array = next(self.index_generator) |
|
|
1530 |
# The transformation of images is not under thread lock |
|
|
1531 |
# so it can be done in parallel |
|
|
1532 |
return self._get_batches_of_transformed_samples(index_array) |
|
|
1533 |
|
|
|
1534 |
|
|
|
1535 |
def _iter_valid_files(directory, white_list_formats, follow_links): |
|
|
1536 |
"""Iterates on files with extension in `white_list_formats` contained in `directory`. |
|
|
1537 |
|
|
|
1538 |
# Arguments |
|
|
1539 |
directory: Absolute path to the directory |
|
|
1540 |
containing files to be counted |
|
|
1541 |
white_list_formats: Set of strings containing allowed extensions for |
|
|
1542 |
the files to be counted. |
|
|
1543 |
follow_links: Boolean. |
|
|
1544 |
|
|
|
1545 |
# Yields |
|
|
1546 |
Tuple of (root, filename) with extension in `white_list_formats`. |
|
|
1547 |
""" |
|
|
1548 |
def _recursive_list(subpath): |
|
|
1549 |
return sorted(os.walk(subpath, followlinks=follow_links), |
|
|
1550 |
key=lambda x: x[0]) |
|
|
1551 |
|
|
|
1552 |
for root, _, files in _recursive_list(directory): |
|
|
1553 |
for fname in sorted(files): |
|
|
1554 |
for extension in white_list_formats: |
|
|
1555 |
if fname.lower().endswith('.tiff'): |
|
|
1556 |
warnings.warn('Using \'.tiff\' files with multiple bands ' |
|
|
1557 |
'will cause distortion. ' |
|
|
1558 |
'Please verify your output.') |
|
|
1559 |
if fname.lower().endswith('.' + extension): |
|
|
1560 |
yield root, fname |
|
|
1561 |
|
|
|
1562 |
|
|
|
1563 |
def _count_valid_files_in_directory(directory, |
|
|
1564 |
white_list_formats, |
|
|
1565 |
split, |
|
|
1566 |
follow_links): |
|
|
1567 |
"""Counts files with extension in `white_list_formats` contained in `directory`. |
|
|
1568 |
|
|
|
1569 |
# Arguments |
|
|
1570 |
directory: absolute path to the directory |
|
|
1571 |
containing files to be counted |
|
|
1572 |
white_list_formats: set of strings containing allowed extensions for |
|
|
1573 |
the files to be counted. |
|
|
1574 |
split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into |
|
|
1575 |
account a certain fraction of files in each directory. |
|
|
1576 |
E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent |
|
|
1577 |
of images in each directory. |
|
|
1578 |
follow_links: boolean. |
|
|
1579 |
|
|
|
1580 |
# Returns |
|
|
1581 |
the count of files with extension in `white_list_formats` contained in |
|
|
1582 |
the directory. |
|
|
1583 |
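
    # Example
    A worked illustration of the split arithmetic (the file count is
    made up): with 10 valid files and `split=(0.6, 1.0)`,
    `start, stop = 6, 10`, so the function returns 4.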
""" |
|
|
1584 |
num_files = len(list( |
|
|
1585 |
_iter_valid_files(directory, white_list_formats, follow_links))) |
|
|
1586 |
if split: |
|
|
1587 |
start, stop = int(split[0] * num_files), int(split[1] * num_files) |
|
|
1588 |
else: |
|
|
1589 |
start, stop = 0, num_files |
|
|
1590 |
return stop - start |
|
|
1591 |
|
|
|
1592 |
|
|
|
1593 |
def _list_valid_filenames_in_directory(directory, white_list_formats, split, |
|
|
1594 |
class_indices, follow_links): |
|
|
1595 |
"""Lists paths of files in `subdir` with extensions in `white_list_formats`. |
|
|
1596 |
|
|
|
1597 |
# Arguments |
|
|
1598 |
directory: absolute path to a directory containing the files to list. |
|
|
1599 |
The directory name is used as class label |
|
|
1600 |
and must be a key of `class_indices`. |
|
|
1601 |
white_list_formats: set of strings containing allowed extensions for |
|
|
1602 |
the files to be counted. |
|
|
1603 |
split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into |
|
|
1604 |
account a certain fraction of files in each directory. |
|
|
1605 |
E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent |
|
|
1606 |
of images in each directory. |
|
|
1607 |
class_indices: dictionary mapping a class name to its index. |
|
|
1608 |
follow_links: boolean. |
|
|
1609 |
|
|
|
1610 |
# Returns |
|
|
1611 |
classes: a list of class indices |
|
|
1612 |
filenames: the path of valid files in `directory`, relative from |
|
|
1613 |
`directory`'s parent (e.g., if `directory` is "dataset/class1", |
|
|
1614 |
the filenames will be |
|
|
1615 |
`["class1/file1.jpg", "class1/file2.jpg", ...]`). |
|
|
1616 |
""" |
|
|
1617 |
dirname = os.path.basename(directory) |
|
|
1618 |
if split: |
|
|
1619 |
num_files = len(list( |
|
|
1620 |
_iter_valid_files(directory, white_list_formats, follow_links))) |
|
|
1621 |
start, stop = int(split[0] * num_files), int(split[1] * num_files) |
|
|
1622 |
valid_files = list( |
|
|
1623 |
_iter_valid_files( |
|
|
1624 |
directory, white_list_formats, follow_links))[start: stop] |
|
|
1625 |
else: |
|
|
1626 |
valid_files = _iter_valid_files( |
|
|
1627 |
directory, white_list_formats, follow_links) |
|
|
1628 |
|
|
|
1629 |
classes = [] |
|
|
1630 |
filenames = [] |
|
|
1631 |
for root, fname in valid_files: |
|
|
1632 |
classes.append(class_indices[dirname]) |
|
|
1633 |
absolute_path = os.path.join(root, fname) |
|
|
1634 |
relative_path = os.path.join( |
|
|
1635 |
dirname, os.path.relpath(absolute_path, directory)) |
|
|
1636 |
filenames.append(relative_path) |
|
|
1637 |
|
|
|
1638 |
return classes, filenames |
|
|
1639 |
|
|
|
1640 |
|
|
|
1641 |
class DirectoryIterator(Iterator): |
|
|
1642 |
"""Iterator capable of reading images from a directory on disk. |
|
|
1643 |
|
|
|
1644 |
# Arguments |
|
|
1645 |
directory: Path to the directory to read images from. |
|
|
1646 |
Each subdirectory in this directory will be |
|
|
1647 |
considered to contain images from one class, |
|
|
1648 |
or alternatively you could specify class subdirectories |
|
|
1649 |
via the `classes` argument. |
|
|
1650 |
image_data_generator: Instance of `ImageDataGenerator` |
|
|
1651 |
to use for random transformations and normalization. |
|
|
1652 |
target_size: tuple of integers, dimensions to resize input images to. |
|
|
1653 |
color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images. |
|
|
1654 |
classes: Optional list of strings, names of subdirectories |
|
|
1655 |
containing images from each class (e.g. `["dogs", "cats"]`). |
|
|
1656 |
It will be computed automatically if not set. |
|
|
1657 |
class_mode: Mode for yielding the targets: |
|
|
1658 |
`"binary"`: binary targets (if there are only two classes), |
|
|
1659 |
`"categorical"`: categorical targets, |
|
|
1660 |
`"sparse"`: integer targets, |
|
|
1661 |
`"input"`: targets are images identical to input images (mainly |
|
|
1662 |
used to work with autoencoders), |
|
|
1663 |
`None`: no targets get yielded (only input images are yielded). |
|
|
1664 |
batch_size: Integer, size of a batch. |
|
|
1665 |
shuffle: Boolean, whether to shuffle the data between epochs. |
|
|
1666 |
seed: Random seed for data shuffling. |
|
|
1667 |
data_format: String, one of `channels_first`, `channels_last`. |
|
|
1668 |
save_to_dir: Optional directory where to save the pictures |
|
|
1669 |
being yielded, in a viewable format. This is useful |
|
|
1670 |
for visualizing the random transformations being |
|
|
1671 |
applied, for debugging purposes. |
|
|
1672 |
save_prefix: String prefix to use for saving sample |
|
|
1673 |
images (if `save_to_dir` is set). |
|
|
1674 |
save_format: Format to use for saving sample images |
|
|
1675 |
(if `save_to_dir` is set). |
|
|
1676 |
subset: Subset of data (`"training"` or `"validation"`) if |
|
|
1677 |
validation_split is set in ImageDataGenerator. |
|
|
1678 |
interpolation: Interpolation method used to resample the image if the |
|
|
1679 |
target size is different from that of the loaded image. |
|
|
1680 |
Supported methods are "nearest", "bilinear", and "bicubic". |
|
|
1681 |
If PIL version 1.1.3 or newer is installed, "lanczos" is also |
|
|
1682 |
supported. If PIL version 3.4.0 or newer is installed, "box" and |
|
|
1683 |
"hamming" are also supported. By default, "nearest" is used. |
|
|
1684 |
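
    # Example
    A minimal sketch, assuming a directory layout with one subdirectory
    per class (the paths and settings are illustrative):

    ```python
    # data/train/cats/*.jpg and data/train/dogs/*.jpg
    datagen = ImageDataGenerator(rescale=1. / 255)
    it = DirectoryIterator('data/train', datagen,
                           target_size=(150, 150),
                           class_mode='binary',
                           batch_size=32)
    batch_x, batch_y = it[0]  # up to 32 images of shape (150, 150, 3)
    ```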
""" |
|
|
1685 |
|
|
|
1686 |
def __init__(self, directory, image_data_generator, |
|
|
1687 |
target_size=(256, 256), color_mode='rgb', |
|
|
1688 |
classes=None, class_mode='categorical', |
|
|
1689 |
batch_size=32, shuffle=True, seed=None, |
|
|
1690 |
data_format=None, |
|
|
1691 |
save_to_dir=None, save_prefix='', save_format='png', |
|
|
1692 |
follow_links=False, |
|
|
1693 |
subset=None, |
|
|
1694 |
interpolation='nearest'): |
|
|
1695 |
if data_format is None: |
|
|
1696 |
data_format = backend.image_data_format() |
|
|
1697 |
self.directory = directory |
|
|
1698 |
self.image_data_generator = image_data_generator |
|
|
1699 |
self.target_size = tuple(target_size) |
|
|
1700 |
if color_mode not in {'rgb', 'grayscale'}: |
|
|
1701 |
raise ValueError('Invalid color mode:', color_mode, |
|
|
1702 |
'; expected "rgb" or "grayscale".') |
|
|
1703 |
self.color_mode = color_mode |
|
|
1704 |
self.data_format = data_format |
|
|
1705 |
if self.color_mode == 'rgb': |
|
|
1706 |
if self.data_format == 'channels_last': |
|
|
1707 |
self.image_shape = self.target_size + (3,) |
|
|
1708 |
else: |
|
|
1709 |
self.image_shape = (3,) + self.target_size |
|
|
1710 |
else: |
|
|
1711 |
if self.data_format == 'channels_last': |
|
|
1712 |
self.image_shape = self.target_size + (1,) |
|
|
1713 |
else: |
|
|
1714 |
self.image_shape = (1,) + self.target_size |
|
|
1715 |
self.classes = classes |
|
|
1716 |
if class_mode not in {'categorical', 'binary', 'sparse', |
|
|
1717 |
'input', None}: |
|
|
1718 |
raise ValueError('Invalid class_mode:', class_mode, |
|
|
1719 |
'; expected one of "categorical", ' |
|
|
1720 |
'"binary", "sparse", "input"' |
|
|
1721 |
' or None.') |
|
|
1722 |
self.class_mode = class_mode |
|
|
1723 |
self.save_to_dir = save_to_dir |
|
|
1724 |
self.save_prefix = save_prefix |
|
|
1725 |
self.save_format = save_format |
|
|
1726 |
self.interpolation = interpolation |
|
|
1727 |
|
|
|
1728 |
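        # Map the subset name to a (start, stop) fraction of the files in
        # each class directory: the first `validation_split` fraction is
        # reserved for validation, the remainder for training.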
        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError('Invalid subset name: ', subset,
                                 '; expected "training" or "validation"')
        else:
            split = None
        self.subset = subset

        white_list_formats = {'png', 'jpg', 'jpeg', 'bmp',
                              'ppm', 'tif', 'tiff'}
        # First, count the number of samples and classes.
        self.samples = 0

        if not classes:
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

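        # Count valid files in parallel, one task per class subdirectory.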
        pool = multiprocessing.pool.ThreadPool()
        function_partial = partial(_count_valid_files_in_directory,
                                   white_list_formats=white_list_formats,
                                   follow_links=follow_links,
                                   split=split)
        self.samples = sum(pool.map(function_partial,
                                    (os.path.join(directory, subdir)
                                     for subdir in classes)))

        print('Found %d images belonging to %d classes.' %
              (self.samples, self.num_classes))

        # Second, build an index of the images
        # in the different class subfolders.
        results = []
        self.filenames = []
        self.classes = np.zeros((self.samples,), dtype='int32')
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(
                pool.apply_async(_list_valid_filenames_in_directory,
                                 (dirpath, white_list_formats, split,
                                  self.class_indices, follow_links)))
        for res in results:
            classes, filenames = res.get()
            self.classes[i:i + len(classes)] = classes
            self.filenames += filenames
            i += len(classes)

        pool.close()
        pool.join()
        super(DirectoryIterator, self).__init__(self.samples,
                                                batch_size,
                                                shuffle,
                                                seed)

    def _get_batches_of_transformed_samples(self, index_array):
        batch_x = np.zeros(
            (len(index_array),) + self.image_shape,
            dtype=backend.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for i, j in enumerate(index_array):
            fname = self.filenames[j]
            img = load_img(os.path.join(self.directory, fname),
                           grayscale=grayscale,
                           target_size=self.target_size,
                           interpolation=self.interpolation)
            x = img_to_array(img, data_format=self.data_format)
            params = self.image_data_generator.get_random_transform(x.shape)
            x = self.image_data_generator.apply_transform(x, params)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e7),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(backend.floatx())
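        # 'categorical' mode one-hot encodes the integer class labels.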
        elif self.class_mode == 'categorical':
            batch_y = np.zeros(
                (len(batch_x), self.num_classes),
                dtype=backend.floatx())
            for i, label in enumerate(self.classes[index_array]):
                batch_y[i, label] = 1.
        else:
            return batch_x
        return batch_x, batch_y

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)