Diff of /mediaug/dataset.py [000000] .. [05e710]

Switch to side-by-side view

--- a
+++ b/mediaug/dataset.py
@@ -0,0 +1,134 @@
+import os
+from os.path import join
+import numpy as np
+from PIL import Image
+import cv2
+from tqdm import tqdm
+from random import choice
+from mediaug.image_utils import read_png, save_img
+from mediaug.download import get_data_cache
+from random import randint
+# TODO: Documentation
+
+class DataPoint:
+
+    def __init__(self, img_path, mask_path, _class, _id=None):
+        self.img_path = img_path
+        self.mask_path = mask_path
+        self._class = _class
+        if _id is None:
+            self.id = img_path.split('.')[0]
+
+    @property
+    def img(self):
+        return read_png(self.img_path)
+
+    @property
+    def pil_img(self):
+        return Image.fromarray(read_png(self.img_path))
+
+    @property
+    def mask(self):
+        return read_png(self.mask_path)
+
+    @property
+    def pil_mask(self):
+        return Image.fromarray(read_png(self.mask_path))
+
+    def __repr__(self):
+        return f'<img_path: {self.img_path}>\n<mask_path: {self.mask_path}>'
+
+
+class Dataset:
+    """Dataset object for managing image augmentation
+
+    Attributes:
+        data_path (str): Path to the data directory root
+    """
+
+    def __init__(self, data_path=None, classes=None):
+        self.data_path = data_path
+        if not os.path.exists(data_path) and classes is not None:
+            self._create_empty_dataset(classes)
+        if not os.path.exists(data_path) and classes is None:
+            raise ValueError('No data in path or classes.')
+        self._parse(data_path)
+    
+    def _parse(self, data_path):
+        self.data = {}
+        categories =  [x for x in os.listdir(data_path) if not x.startswith('.')]
+        self.data = {key:[] for key in categories}
+        for c in categories:
+            cur_dir = join(data_path, c)
+            for base_name in os.listdir(join(cur_dir, 'image')):
+                name = base_name.split('.')[0]
+                dp = DataPoint(join(cur_dir, 'image', base_name),
+                                join(cur_dir, 'mask', base_name), c, name)
+                self.data[c].append(dp)
+
+    def _create_empty_dataset(self, classes):
+        os.mkdir(self.data_path)
+        self.data = {key:[] for key in classes}
+        for _class in classes:
+            os.mkdir(join(self.data_path, _class))
+            os.mkdir(join(self.data_path, _class, 'image'))
+            os.mkdir(join(self.data_path, _class, 'mask'))
+
+
+    def add_datapoint(self, dp):
+        self.data[dp._class].append(dp)
+
+    
+    def random_sample(self):
+        _class = choice(self.classes)
+        return choice(self.data[_class])
+
+
+    def add_data(self, img, mask, _class, name):
+        img_path = save_img(img, join(self.data_path, _class, 'image', f'{name}.png'))
+        mask_path = save_img(mask, join(self.data_path, _class, 'mask', f'{name}.png'))
+        self.data[_class].append(DataPoint(img_path, mask_path, _class))
+
+
+    def get_data(self, _id):
+        """ Gets a datapoint by id """
+        raise NotImplementedError 
+
+
+    def get_array(self, num_samples=-1, n_last=False, greyscale=False):
+        """ This is of the form:
+        (x_train, y_train), (x_test, y_test)
+        ex: (num_samples, 32, 32, 3)
+        (num_samples, 1)
+        """
+        images = []
+        masks = []
+        for c in tqdm(self.classes):
+            for dp in tqdm(self.data[c][:num_samples]):
+                if greyscale == True:
+                    images.append(cv2.cvtColor(dp.img, cv2.COLOR_BGR2GRAY))
+                    masks.append(cv2.cvtColor(dp.mask, cv2.COLOR_BGR2GRAY))
+                else:
+                    images.append(dp.img)
+                    masks.append(dp.mask)
+        images = np.array(images)
+        masks = np.array(masks)
+        if n_last:
+            images = np.moveaxis(images, 0, -1)
+            masks = np.moveaxis(masks, 0, -1)
+        return images, masks
+    
+
+    @property
+    def classes(self):
+        return list(self.data.keys())
+
+    @property
+    def size(self):
+        size = 0
+        for c in self.classes:
+            size += len(self.data[c])
+        return size
+
+    def __getitem__(self, arg):
+        return self.data[arg]