|
a |
|
b/mediaug/dataset.py |
|
|
1 |
import os |
|
|
2 |
from os.path import join |
|
|
3 |
import numpy as np |
|
|
4 |
from PIL import Image |
|
|
5 |
import cv2 |
|
|
6 |
from tqdm import tqdm |
|
|
7 |
from random import choice |
|
|
8 |
from mediaug.image_utils import read_png, save_img |
|
|
9 |
from mediaug.download import get_data_cache |
|
|
10 |
from random import randint |
|
|
11 |
# TODO: Documentation |
|
|
12 |
|
|
|
13 |
class DataPoint: |
|
|
14 |
|
|
|
15 |
def __init__(self, img_path, mask_path, _class, _id=None): |
|
|
16 |
self.img_path = img_path |
|
|
17 |
self.mask_path = mask_path |
|
|
18 |
self._class = _class |
|
|
19 |
if _id is None: |
|
|
20 |
self.id = img_path.split('.')[0] |
|
|
21 |
|
|
|
22 |
@property |
|
|
23 |
def img(self): |
|
|
24 |
return read_png(self.img_path) |
|
|
25 |
|
|
|
26 |
@property |
|
|
27 |
def pil_img(self): |
|
|
28 |
return Image.fromarray(read_png(self.img_path)) |
|
|
29 |
|
|
|
30 |
@property |
|
|
31 |
def mask(self): |
|
|
32 |
return read_png(self.mask_path) |
|
|
33 |
|
|
|
34 |
@property |
|
|
35 |
def pil_mask(self): |
|
|
36 |
return Image.fromarray(read_png(self.mask_path)) |
|
|
37 |
|
|
|
38 |
def __repr__(self): |
|
|
39 |
return f'<img_path: {self.img_path}>\n<mask_path: {self.mask_path}>' |
|
|
40 |
|
|
|
41 |
|
|
|
42 |
class Dataset: |
|
|
43 |
"""Dataset object for managing image augmentation |
|
|
44 |
|
|
|
45 |
Attributes: |
|
|
46 |
data_path (str): Path to the data directory root |
|
|
47 |
""" |
|
|
48 |
|
|
|
49 |
def __init__(self, data_path=None, classes=None): |
|
|
50 |
self.data_path = data_path |
|
|
51 |
if not os.path.exists(data_path) and classes is not None: |
|
|
52 |
self._create_empty_dataset(classes) |
|
|
53 |
if not os.path.exists(data_path) and classes is None: |
|
|
54 |
raise ValueError('No data in path or classes.') |
|
|
55 |
self._parse(data_path) |
|
|
56 |
|
|
|
57 |
def _parse(self, data_path): |
|
|
58 |
self.data = {} |
|
|
59 |
categories = [x for x in os.listdir(data_path) if not x.startswith('.')] |
|
|
60 |
self.data = {key:[] for key in categories} |
|
|
61 |
for c in categories: |
|
|
62 |
cur_dir = join(data_path, c) |
|
|
63 |
for base_name in os.listdir(join(cur_dir, 'image')): |
|
|
64 |
name = base_name.split('.')[0] |
|
|
65 |
dp = DataPoint(join(cur_dir, 'image', base_name), |
|
|
66 |
join(cur_dir, 'mask', base_name), c, name) |
|
|
67 |
self.data[c].append(dp) |
|
|
68 |
|
|
|
69 |
def _create_empty_dataset(self, classes): |
|
|
70 |
os.mkdir(self.data_path) |
|
|
71 |
self.data = {key:[] for key in classes} |
|
|
72 |
for _class in classes: |
|
|
73 |
os.mkdir(join(self.data_path, _class)) |
|
|
74 |
os.mkdir(join(self.data_path, _class, 'image')) |
|
|
75 |
os.mkdir(join(self.data_path, _class, 'mask')) |
|
|
76 |
|
|
|
77 |
|
|
|
78 |
def add_datapoint(self, dp): |
|
|
79 |
self.data[dp._class].append(dp) |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
def random_sample(self): |
|
|
83 |
_class = choice(self.classes) |
|
|
84 |
return choice(self.data[_class]) |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def add_data(self, img, mask, _class, name): |
|
|
88 |
img_path = save_img(img, join(self.data_path, _class, 'image', f'{name}.png')) |
|
|
89 |
mask_path = save_img(mask, join(self.data_path, _class, 'mask', f'{name}.png')) |
|
|
90 |
self.data[_class].append(DataPoint(img_path, mask_path, _class)) |
|
|
91 |
|
|
|
92 |
|
|
|
93 |
def get_data(self, _id): |
|
|
94 |
""" Gets a datapoint by id """ |
|
|
95 |
raise NotImplementedError |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def get_array(self, num_samples=-1, n_last=False, greyscale=False): |
|
|
99 |
""" This is of the form: |
|
|
100 |
(x_train, y_train), (x_test, y_test) |
|
|
101 |
ex: (num_samples, 32, 32, 3) |
|
|
102 |
(num_samples, 1) |
|
|
103 |
""" |
|
|
104 |
images = [] |
|
|
105 |
masks = [] |
|
|
106 |
for c in tqdm(self.classes): |
|
|
107 |
for dp in tqdm(self.data[c][:num_samples]): |
|
|
108 |
if greyscale == True: |
|
|
109 |
images.append(cv2.cvtColor(dp.img, cv2.COLOR_BGR2GRAY)) |
|
|
110 |
masks.append(cv2.cvtColor(dp.mask, cv2.COLOR_BGR2GRAY)) |
|
|
111 |
else: |
|
|
112 |
images.append(dp.img) |
|
|
113 |
masks.append(dp.mask) |
|
|
114 |
images = np.array(images) |
|
|
115 |
masks = np.array(masks) |
|
|
116 |
if n_last: |
|
|
117 |
images = np.moveaxis(images, 0, -1) |
|
|
118 |
masks = np.moveaxis(masks, 0, -1) |
|
|
119 |
return images, masks |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
@property |
|
|
123 |
def classes(self): |
|
|
124 |
return list(self.data.keys()) |
|
|
125 |
|
|
|
126 |
@property |
|
|
127 |
def size(self): |
|
|
128 |
size = 0 |
|
|
129 |
for c in self.classes: |
|
|
130 |
size += len(self.data[c]) |
|
|
131 |
return size |
|
|
132 |
|
|
|
133 |
def __getitem__(self, arg): |
|
|
134 |
return self.data[arg] |