--- a
+++ b/dsb2018_topcoders/selim/datasets/base.py
@@ -0,0 +1,114 @@
+
+import os
+import random
+import time
+from abc import abstractmethod
+
+import cv2
+import numpy as np
+from keras.applications import imagenet_utils
+from keras.preprocessing.image import Iterator, load_img, img_to_array
+
+from params import args
+
+
+class BaseMaskDatasetIterator(Iterator):
+    def __init__(self,
+                 images_dir,
+                 masks_dir,
+                 labels_dir,
+                 image_ids,
+                 crop_shape,
+                 preprocessing_function,
+                 random_transformer=None,
+                 batch_size=8,
+                 shuffle=True,
+                 image_name_template=None,
+                 mask_template=None,
+                 label_template=None,
+                 padding=32,
+                 seed=None,
+                 grayscale_mask=False,
+                 ):
+        self.images_dir = images_dir
+        self.masks_dir = masks_dir
+        self.labels_dir = labels_dir
+        self.image_ids = image_ids
+        self.image_name_template = image_name_template
+        self.mask_template = mask_template
+        self.label_template = label_template
+        self.random_transformer = random_transformer
+        self.crop_shape = crop_shape
+        self.preprocessing_function = preprocessing_function
+        self.padding = padding
+        self.grayscale_mask = grayscale_mask
+        if seed is None:
+            seed = np.uint32(time.time() * 1000)
+
+        super(BaseMaskDatasetIterator, self).__init__(len(self.image_ids), batch_size, shuffle, seed)
+
+    @abstractmethod
+    def transform_mask(self, mask, image):
+        raise NotImplementedError
+
+    def augment_and_crop_mask_image(self, mask, image, label, img_id, crop_shape):
+        return mask, image, label
+
+    def transform_batch_y(self, batch_y):
+        return batch_y
+
+    def _get_batches_of_transformed_samples(self, index_array):
+        batch_x = []
+        batch_y = []
+
+        for batch_index, image_index in enumerate(index_array):
+            id = self.image_ids[image_index]
+            img_name = self.image_name_template.format(id=id)
+            path = os.path.join(self.images_dir, img_name)
+            image = np.array(img_to_array(load_img(path)), "uint8")
+            mask_name = self.mask_template.format(id=id)
+            mask_path = os.path.join(self.masks_dir, mask_name)
+            mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
+            label = cv2.imread(os.path.join(self.labels_dir, self.label_template.format(id=id)), cv2.IMREAD_UNCHANGED)
+            if args.use_full_masks:
+                mask[...,0] = (label > 0) * 255
+            if self.crop_shape is not None:
+                crop_mask, crop_image, crop_label = self.augment_and_crop_mask_image(mask, image, label, id, self.crop_shape)
+                data = self.random_transformer(image=np.array(crop_image, "uint8"), mask=np.array(crop_mask, "uint8"))
+                crop_image, crop_mask = data['image'], data['mask']
+                if len(np.shape(crop_mask)) == 2:
+                    crop_mask = np.expand_dims(crop_mask, -1)
+                crop_mask = self.transform_mask(crop_mask, crop_image)
+                batch_x.append(crop_image)
+                batch_y.append(crop_mask)
+            else:
+                x0, x1, y0, y1 = 0, 0, 0, 0
+                if (image.shape[1] % 32) != 0:
+                    x0 = int((32 - image.shape[1] % 32) / 2)
+                    x1 = (32 - image.shape[1] % 32) - x0
+                if (image.shape[0] % 32) != 0:
+                    y0 = int((32 - image.shape[0] % 32) / 2)
+                    y1 = (32 - image.shape[0] % 32) - y0
+                image = np.pad(image, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
+                mask = np.pad(mask, ((y0, y1), (x0, x1), (0, 0)), 'reflect')
+                batch_x.append(image)
+                mask = self.transform_mask(mask, image)
+
+                batch_y.append(mask)
+        batch_x = np.array(batch_x, dtype="float32")
+        batch_y = np.array(batch_y, dtype="float32")
+        if self.preprocessing_function:
+            batch_x = imagenet_utils.preprocess_input(batch_x, mode=self.preprocessing_function)
+        return self.transform_batch_x(batch_x), self.transform_batch_y(batch_y)
+
+    def transform_batch_x(self, batch_x):
+        return batch_x
+
+
+    def next(self):
+
+        with self.lock:
+            index_array = next(self.index_generator)
+        return self._get_batches_of_transformed_samples(index_array)
+
+