Diff of /mediaug/dataset.py [000000] .. [05e710]

Switch to unified view

a b/mediaug/dataset.py
1
import os
2
from os.path import join
3
import numpy as np
4
from PIL import Image
5
import cv2
6
from tqdm import tqdm
7
from random import choice
8
from mediaug.image_utils import read_png, save_img
9
from mediaug.download import get_data_cache
10
from random import randint
11
# TODO: Documentation
12
13
class DataPoint:
14
15
    def __init__(self, img_path, mask_path, _class, _id=None):
16
        self.img_path = img_path
17
        self.mask_path = mask_path
18
        self._class = _class
19
        if _id is None:
20
            self.id = img_path.split('.')[0]
21
22
    @property
23
    def img(self):
24
        return read_png(self.img_path)
25
26
    @property
27
    def pil_img(self):
28
        return Image.fromarray(read_png(self.img_path))
29
30
    @property
31
    def mask(self):
32
        return read_png(self.mask_path)
33
34
    @property
35
    def pil_mask(self):
36
        return Image.fromarray(read_png(self.mask_path))
37
38
    def __repr__(self):
39
        return f'<img_path: {self.img_path}>\n<mask_path: {self.mask_path}>'
40
41
42
class Dataset:
43
    """Dataset object for managing image augmentation
44
45
    Attributes:
46
        data_path (str): Path to the data directory root
47
    """
48
49
    def __init__(self, data_path=None, classes=None):
50
        self.data_path = data_path
51
        if not os.path.exists(data_path) and classes is not None:
52
            self._create_empty_dataset(classes)
53
        if not os.path.exists(data_path) and classes is None:
54
            raise ValueError('No data in path or classes.')
55
        self._parse(data_path)
56
    
57
    def _parse(self, data_path):
58
        self.data = {}
59
        categories =  [x for x in os.listdir(data_path) if not x.startswith('.')]
60
        self.data = {key:[] for key in categories}
61
        for c in categories:
62
            cur_dir = join(data_path, c)
63
            for base_name in os.listdir(join(cur_dir, 'image')):
64
                name = base_name.split('.')[0]
65
                dp = DataPoint(join(cur_dir, 'image', base_name),
66
                                join(cur_dir, 'mask', base_name), c, name)
67
                self.data[c].append(dp)
68
69
    def _create_empty_dataset(self, classes):
70
        os.mkdir(self.data_path)
71
        self.data = {key:[] for key in classes}
72
        for _class in classes:
73
            os.mkdir(join(self.data_path, _class))
74
            os.mkdir(join(self.data_path, _class, 'image'))
75
            os.mkdir(join(self.data_path, _class, 'mask'))
76
77
78
    def add_datapoint(self, dp):
79
        self.data[dp._class].append(dp)
80
81
    
82
    def random_sample(self):
83
        _class = choice(self.classes)
84
        return choice(self.data[_class])
85
86
87
    def add_data(self, img, mask, _class, name):
88
        img_path = save_img(img, join(self.data_path, _class, 'image', f'{name}.png'))
89
        mask_path = save_img(mask, join(self.data_path, _class, 'mask', f'{name}.png'))
90
        self.data[_class].append(DataPoint(img_path, mask_path, _class))
91
92
93
    def get_data(self, _id):
94
        """ Gets a datapoint by id """
95
        raise NotImplementedError 
96
97
98
    def get_array(self, num_samples=-1, n_last=False, greyscale=False):
99
        """ This is of the form:
100
        (x_train, y_train), (x_test, y_test)
101
        ex: (num_samples, 32, 32, 3)
102
        (num_samples, 1)
103
        """
104
        images = []
105
        masks = []
106
        for c in tqdm(self.classes):
107
            for dp in tqdm(self.data[c][:num_samples]):
108
                if greyscale == True:
109
                    images.append(cv2.cvtColor(dp.img, cv2.COLOR_BGR2GRAY))
110
                    masks.append(cv2.cvtColor(dp.mask, cv2.COLOR_BGR2GRAY))
111
                else:
112
                    images.append(dp.img)
113
                    masks.append(dp.mask)
114
        images = np.array(images)
115
        masks = np.array(masks)
116
        if n_last:
117
            images = np.moveaxis(images, 0, -1)
118
            masks = np.moveaxis(masks, 0, -1)
119
        return images, masks
120
    
121
122
    @property
123
    def classes(self):
124
        return list(self.data.keys())
125
126
    @property
127
    def size(self):
128
        size = 0
129
        for c in self.classes:
130
            size += len(self.data[c])
131
        return size
132
133
    def __getitem__(self, arg):
134
        return self.data[arg]