--- a
+++ b/dsb2018_topcoders/selim/datasets/base.py
@@ -0,0 +1,114 @@
+import os
+import random
+import time
+from abc import abstractmethod
+
+import cv2
+import numpy as np
+from keras.applications import imagenet_utils
+from keras.preprocessing.image import Iterator, load_img, img_to_array
+
+from params import args
+
+
+class BaseMaskDatasetIterator(Iterator):
+    def __init__(self,
+                 images_dir,
+                 masks_dir,
+                 labels_dir,
+                 image_ids,
+                 crop_shape,
+                 preprocessing_function,
+                 random_transformer=None,
+                 batch_size=8,
+                 shuffle=True,
+                 image_name_template=None,
+                 mask_template=None,
+                 label_template=None,
+                 padding=32,
+                 seed=None,
+                 grayscale_mask=False,
+                 ):
+        self.images_dir = images_dir
+        self.masks_dir = masks_dir
+        self.labels_dir = labels_dir
+        self.image_ids = image_ids
+        self.image_name_template = image_name_template
+        self.mask_template = mask_template
+        self.label_template = label_template
+        self.random_transformer = random_transformer
+        self.crop_shape = crop_shape
+        self.preprocessing_function = preprocessing_function
+        self.padding = padding
+        self.grayscale_mask = grayscale_mask
+        # derive a seed from the clock when none is supplied
+        if seed is None:
+            seed = np.uint32(time.time() * 1000)
+
+        super(BaseMaskDatasetIterator, self).__init__(len(self.image_ids), batch_size, shuffle, seed)
+
+    @abstractmethod
+    def transform_mask(self, mask, image):
+        raise NotImplementedError
+
+    def augment_and_crop_mask_image(self, mask, image, label, img_id, crop_shape):
+        return mask, image, label
+
+    def transform_batch_y(self, batch_y):
+        return batch_y
+
+    def _get_batches_of_transformed_samples(self, index_array):
+        batch_x = []
+        batch_y = []
+
+        for batch_index, image_index in enumerate(index_array):
+            img_id = self.image_ids[image_index]
+            img_name = self.image_name_template.format(id=img_id)
+            path = os.path.join(self.images_dir, img_name)
+            image = np.array(img_to_array(load_img(path)), "uint8")
+            mask_name = self.mask_template.format(id=img_id)
+            mask_path = os.path.join(self.masks_dir, mask_name)
+            mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
+            label = cv2.imread(os.path.join(self.labels_dir, self.label_template.format(id=img_id)), cv2.IMREAD_UNCHANGED)
+            # optionally overwrite the first mask channel with the binary full mask derived from the label image
+            if args.use_full_masks:
+                mask[..., 0] = (label > 0) * 255
+            if self.crop_shape is not None:
+                # fixed crop size: crop/augment to crop_shape, then apply the random transformer
+                crop_mask, crop_image, crop_label = self.augment_and_crop_mask_image(mask, image, label, img_id, self.crop_shape)
+                data = self.random_transformer(image=np.array(crop_image, "uint8"), mask=np.array(crop_mask, "uint8"))
+                crop_image, crop_mask = data['image'], data['mask']
+                if len(np.shape(crop_mask)) == 2:
+                    crop_mask = np.expand_dims(crop_mask, -1)
+                crop_mask = self.transform_mask(crop_mask, crop_image)
+                batch_x.append(crop_image)
+                batch_y.append(crop_mask)
+            else:
+                # no fixed crop: reflect-pad height and width up to the next multiple of 32
+                x0, x1, y0, y1 = 0, 0, 0, 0
+                if (image.shape[1] % 32) != 0:
+                    x0 = int((32 - image.shape[1] % 32) / 2)
+                    x1 = (32 - image.shape[1] % 32) - x0
+                if (image.shape[0] % 32) != 0:
+                    y0 = int((32 - image.shape[0] % 32) / 2)
+                    y1 = (32 - image.shape[0] % 32) - y0
+                image = np.pad(image, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
+                mask = np.pad(mask, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
+                batch_x.append(image)
+                mask = self.transform_mask(mask, image)
+                batch_y.append(mask)
+
+        batch_x = np.array(batch_x, dtype="float32")
+        batch_y = np.array(batch_y, dtype="float32")
+        if self.preprocessing_function:
+            batch_x = imagenet_utils.preprocess_input(batch_x, mode=self.preprocessing_function)
+        return self.transform_batch_x(batch_x), self.transform_batch_y(batch_y)
+
+    def transform_batch_x(self, batch_x):
+        return batch_x
+
+    def next(self):
+        # only index generation needs the lock; batch assembly can run concurrently
+        with self.lock:
+            index_array = next(self.index_generator)
+        return self._get_batches_of_transformed_samples(index_array)
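For orientation, a minimal usage sketch of the iterator follows. It is not part of the diff: the subclass, directory layout, filename templates, and the "torch" preprocessing mode are illustrative assumptions, and a real training setup with a fixed crop_shape would also need a random_transformer callable (e.g. an albumentations Compose) returning a dict with 'image' and 'mask' keys.

    # Hypothetical sketch only -- names, paths and templates are assumptions, not taken from the repo.
    from datasets.base import BaseMaskDatasetIterator


    class BinaryMaskIterator(BaseMaskDatasetIterator):
        # Concrete subclasses must implement transform_mask(); here masks are simply scaled to [0, 1].
        def transform_mask(self, mask, image):
            return mask.astype("float32") / 255.0


    it = BinaryMaskIterator(
        images_dir="data/images",        # assumed layout
        masks_dir="data/masks",
        labels_dir="data/labels",
        image_ids=["0001"],
        crop_shape=None,                 # None -> pad-to-multiple-of-32 branch, no augmentation required
        preprocessing_function="torch",  # forwarded to imagenet_utils.preprocess_input(mode=...)
        batch_size=1,
        shuffle=False,
        image_name_template="{id}.png",
        mask_template="{id}.png",
        label_template="{id}.tif",
    )

    batch_x, batch_y = next(it)          # one preprocessed image batch and its transformed mask batch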