--- a +++ b/mediaug/dataset.py @@ -0,0 +1,134 @@ +import os +from os.path import join +import numpy as np +from PIL import Image +import cv2 +from tqdm import tqdm +from random import choice +from mediaug.image_utils import read_png, save_img +from mediaug.download import get_data_cache +from random import randint +# TODO: Documentation + +class DataPoint: + + def __init__(self, img_path, mask_path, _class, _id=None): + self.img_path = img_path + self.mask_path = mask_path + self._class = _class + if _id is None: + self.id = img_path.split('.')[0] + + @property + def img(self): + return read_png(self.img_path) + + @property + def pil_img(self): + return Image.fromarray(read_png(self.img_path)) + + @property + def mask(self): + return read_png(self.mask_path) + + @property + def pil_mask(self): + return Image.fromarray(read_png(self.mask_path)) + + def __repr__(self): + return f'<img_path: {self.img_path}>\n<mask_path: {self.mask_path}>' + + +class Dataset: + """Dataset object for managing image augmentation + + Attributes: + data_path (str): Path to the data directory root + """ + + def __init__(self, data_path=None, classes=None): + self.data_path = data_path + if not os.path.exists(data_path) and classes is not None: + self._create_empty_dataset(classes) + if not os.path.exists(data_path) and classes is None: + raise ValueError('No data in path or classes.') + self._parse(data_path) + + def _parse(self, data_path): + self.data = {} + categories = [x for x in os.listdir(data_path) if not x.startswith('.')] + self.data = {key:[] for key in categories} + for c in categories: + cur_dir = join(data_path, c) + for base_name in os.listdir(join(cur_dir, 'image')): + name = base_name.split('.')[0] + dp = DataPoint(join(cur_dir, 'image', base_name), + join(cur_dir, 'mask', base_name), c, name) + self.data[c].append(dp) + + def _create_empty_dataset(self, classes): + os.mkdir(self.data_path) + self.data = {key:[] for key in classes} + for _class in classes: + os.mkdir(join(self.data_path, _class)) + os.mkdir(join(self.data_path, _class, 'image')) + os.mkdir(join(self.data_path, _class, 'mask')) + + + def add_datapoint(self, dp): + self.data[dp._class].append(dp) + + + def random_sample(self): + _class = choice(self.classes) + return choice(self.data[_class]) + + + def add_data(self, img, mask, _class, name): + img_path = save_img(img, join(self.data_path, _class, 'image', f'{name}.png')) + mask_path = save_img(mask, join(self.data_path, _class, 'mask', f'{name}.png')) + self.data[_class].append(DataPoint(img_path, mask_path, _class)) + + + def get_data(self, _id): + """ Gets a datapoint by id """ + raise NotImplementedError + + + def get_array(self, num_samples=-1, n_last=False, greyscale=False): + """ This is of the form: + (x_train, y_train), (x_test, y_test) + ex: (num_samples, 32, 32, 3) + (num_samples, 1) + """ + images = [] + masks = [] + for c in tqdm(self.classes): + for dp in tqdm(self.data[c][:num_samples]): + if greyscale == True: + images.append(cv2.cvtColor(dp.img, cv2.COLOR_BGR2GRAY)) + masks.append(cv2.cvtColor(dp.mask, cv2.COLOR_BGR2GRAY)) + else: + images.append(dp.img) + masks.append(dp.mask) + images = np.array(images) + masks = np.array(masks) + if n_last: + images = np.moveaxis(images, 0, -1) + masks = np.moveaxis(masks, 0, -1) + return images, masks + + + @property + def classes(self): + return list(self.data.keys()) + + @property + def size(self): + size = 0 + for c in self.classes: + size += len(self.data[c]) + return size + + def __getitem__(self, arg): + return self.data[arg]