Download this file

115 lines (96 with data), 4.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import random
import time
from abc import abstractmethod
import cv2
import numpy as np
from keras.applications import imagenet_utils
from keras.preprocessing.image import Iterator, load_img, img_to_array
from params import args
class BaseMaskDatasetIterator(Iterator):
def __init__(self,
images_dir,
masks_dir,
labels_dir,
image_ids,
crop_shape,
preprocessing_function,
random_transformer=None,
batch_size=8,
shuffle=True,
image_name_template=None,
mask_template=None,
label_template=None,
padding=32,
seed=None,
grayscale_mask=False,
):
self.images_dir = images_dir
self.masks_dir = masks_dir
self.labels_dir = labels_dir
self.image_ids = image_ids
self.image_name_template = image_name_template
self.mask_template = mask_template
self.label_template = label_template
self.random_transformer = random_transformer
self.crop_shape = crop_shape
self.preprocessing_function = preprocessing_function
self.padding = padding
self.grayscale_mask = grayscale_mask
if seed is None:
seed = np.uint32(time.time() * 1000)
super(BaseMaskDatasetIterator, self).__init__(len(self.image_ids), batch_size, shuffle, seed)
@abstractmethod
def transform_mask(self, mask, image):
raise NotImplementedError
def augment_and_crop_mask_image(self, mask, image, label, img_id, crop_shape):
return mask, image, label
def transform_batch_y(self, batch_y):
return batch_y
def _get_batches_of_transformed_samples(self, index_array):
batch_x = []
batch_y = []
for batch_index, image_index in enumerate(index_array):
id = self.image_ids[image_index]
img_name = self.image_name_template.format(id=id)
path = os.path.join(self.images_dir, img_name)
image = np.array(img_to_array(load_img(path)), "uint8")
mask_name = self.mask_template.format(id=id)
mask_path = os.path.join(self.masks_dir, mask_name)
mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
label = cv2.imread(os.path.join(self.labels_dir, self.label_template.format(id=id)), cv2.IMREAD_UNCHANGED)
if args.use_full_masks:
mask[...,0] = (label > 0) * 255
if self.crop_shape is not None:
crop_mask, crop_image, crop_label = self.augment_and_crop_mask_image(mask, image, label, id, self.crop_shape)
data = self.random_transformer(image=np.array(crop_image, "uint8"), mask=np.array(crop_mask, "uint8"))
crop_image, crop_mask = data['image'], data['mask']
if len(np.shape(crop_mask)) == 2:
crop_mask = np.expand_dims(crop_mask, -1)
crop_mask = self.transform_mask(crop_mask, crop_image)
batch_x.append(crop_image)
batch_y.append(crop_mask)
else:
x0, x1, y0, y1 = 0, 0, 0, 0
if (image.shape[1] % 32) != 0:
x0 = int((32 - image.shape[1] % 32) / 2)
x1 = (32 - image.shape[1] % 32) - x0
if (image.shape[0] % 32) != 0:
y0 = int((32 - image.shape[0] % 32) / 2)
y1 = (32 - image.shape[0] % 32) - y0
image = np.pad(image, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
mask = np.pad(mask, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
batch_x.append(image)
mask = self.transform_mask(mask, image)
batch_y.append(mask)
batch_x = np.array(batch_x, dtype="float32")
batch_y = np.array(batch_y, dtype="float32")
if self.preprocessing_function:
batch_x = imagenet_utils.preprocess_input(batch_x, mode=self.preprocessing_function)
return self.transform_batch_x(batch_x), self.transform_batch_y(batch_y)
def transform_batch_x(self, batch_x):
return batch_x
def next(self):
with self.lock:
index_array = next(self.index_generator)
return self._get_batches_of_transformed_samples(index_array)