|
a |
|
b/sybil/loaders/abstract_loader.py |
|
|
1 |
import torch |
|
|
2 |
import os |
|
|
3 |
import sys |
|
|
4 |
import os.path |
|
|
5 |
import warnings |
|
|
6 |
from sybil.datasets.utils import get_scaled_annotation_mask, IMG_PAD_TOKEN |
|
|
7 |
from sybil.augmentations import ComposeAug |
|
|
8 |
import numpy as np |
|
|
9 |
from abc import ABCMeta, abstractmethod |
|
|
10 |
import hashlib |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
# Extension recorded for cached files; the `cache` class below appends ".npy"
# to it because arrays are stored with np.save.
CACHED_FILES_EXT = ".png"
# Cache sub-directory used when no (or no further) augmentation caching key applies.
DEFAULT_CACHE_DIR = "default/"

# Template for the warning emitted when a cached file fails to load and is
# evicted from the cache; formatted with the exception type.
# NOTE(review): name is misspelled ("CORUPTED"), kept for backward compatibility.
CORUPTED_FILE_ERR = (
    "WARNING! Error processing file from cache - removed file from cache. Error: {}"
)
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def md5(key):
    """Return the hexadecimal MD5 digest of the string *key*."""
    digest = hashlib.md5(key.encode())
    return digest.hexdigest()
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def split_augmentations_by_cache(augmentations):
    """
    Given a list of augmentations, returns a list of tuples. Each tuple
    contains a caching key of the augmentations up to the splitting point,
    and a list of augmentations that should be applied afterwards.

    split_augmentations will contain all possible splits by cachable augmentations,
    ordered from the latest possible one to the former ones.
    The last tuple will have all augmentations (paired with DEFAULT_CACHE_DIR).

    Note - splitting is only done at indexes where all augmentations up to
    them are cachable.
    """
    # List of (cache key, augmentations still to apply), built in application
    # order; the no-split entry comes first and ends up last after reversal.
    split_augmentations = [(DEFAULT_CACHE_DIR, augmentations)]
    key = ""
    for ind, trans in enumerate(augmentations):
        # A split is only valid while every augmentation so far is cachable,
        # so the first non-cachable one ends the scan (equivalent to tracking
        # an `all_prev_cachable` flag that can never become True again).
        if not trans.cachable():
            break
        key += trans.caching_keys()
        # `ind` from enumerate is always < len(augmentations), so the slice
        # needs no guard; the last split pairs the full key with [].
        split_augmentations.append((key, augmentations[ind + 1 :]))

    return list(reversed(split_augmentations))
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def apply_augmentations_and_cache(
    loaded_input, sample, img_path, augmentations, cache, base_key=""
):
    """
    Loads the loaded input by its absolute path and apply the augmentations one
    by one (similar to what the composed one is doing). All first cachable
    transformer's output is cached (until reaching a non cachable one).
    """
    caching_active = True
    cache_key = base_key
    for transformer in augmentations:
        loaded_input = transformer(loaded_input, sample)
        if caching_active and transformer.cachable():
            # Extend the key with this transformer and cache its output.
            cache_key += transformer.caching_keys()
            cache.add(img_path, cache_key, loaded_input["input"])
        else:
            # First non-cachable transformer ends caching for good.
            caching_active = False

    return loaded_input
|
|
80 |
|
|
|
81 |
|
|
|
82 |
class cache:
    """
    Filesystem cache mapping (image path, attribute key) -> numpy array.

    Entries are stored as ``cache_dir/attr_key/<parent dir of image>/<md5 of
    image path><extension>.npy`` files via np.save/np.load.
    """

    def __init__(self, path, extension=CACHED_FILES_EXT):
        # exist_ok avoids a create/exists race when several workers
        # initialize the same cache directory concurrently.
        os.makedirs(path, exist_ok=True)

        self.cache_dir = path
        self.files_extension = extension
        # np.save always writes ".npy" files, so make the stored extension
        # reflect that (e.g. ".png" -> ".png.npy").
        if ".npy" != extension:
            self.files_extension += ".npy"

    def _file_dir(self, attr_key, par_dir):
        # Directory holding all cached entries for this key/parent-dir pair.
        return os.path.join(self.cache_dir, attr_key, par_dir)

    def _file_path(self, attr_key, par_dir, hashed_key):
        # Full path of one cached entry.
        return os.path.join(
            self.cache_dir, attr_key, par_dir, hashed_key + self.files_extension
        )

    def _parent_dir(self, path):
        # Name of the image's immediate parent directory (e.g. the series dir).
        return os.path.basename(os.path.dirname(path))

    def exists(self, image_path, attr_key):
        """Return True if a cached entry exists for (image_path, attr_key)."""
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        return os.path.isfile(self._file_path(attr_key, par_dir, hashed_key))

    def get(self, image_path, attr_key):
        """Load and return the cached array for (image_path, attr_key)."""
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        return np.load(self._file_path(attr_key, par_dir, hashed_key))

    def add(self, image_path, attr_key, image):
        """Store *image* (array-like) under (image_path, attr_key)."""
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        # exist_ok avoids a race when multiple workers add to the same dir.
        os.makedirs(self._file_dir(attr_key, par_dir), exist_ok=True)
        np.save(self._file_path(attr_key, par_dir, hashed_key), image)

    def rem(self, image_path, attr_key):
        """Remove the cached entry for (image_path, attr_key), if any."""
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        try:
            os.remove(self._file_path(attr_key, par_dir, hashed_key))
        # Don't raise error if file not exists.
        except OSError:
            pass
|
|
129 |
|
|
|
130 |
|
|
|
131 |
class abstract_loader:
    """
    Base class for input loaders: loads an input by path and applies the
    configured augmentations, optionally caching intermediate results.
    """

    # NOTE(review): ``__metaclass__`` is Python-2 syntax and has no effect on
    # Python 3, so the @abstractmethod decorators below are not enforced at
    # instantiation time. Kept as-is for backward compatibility.
    __metaclass__ = ABCMeta

    def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
        """
        Args:
            cache_path: directory for the on-disk cache, or None to disable caching.
            augmentations: ordered list of augmentation transformers.
            args: experiment args namespace (reads ``use_annotations``).
            apply_augmentations: whether get_image applies the augmentations.
        """
        self.pad_token = IMG_PAD_TOKEN
        self.augmentations = augmentations
        self.args = args
        self.apply_augmentations = apply_augmentations
        if cache_path is not None:
            self.use_cache = True
            self.cache = cache(cache_path, self.cached_extension)
            self.split_augmentations = split_augmentations_by_cache(augmentations)
        else:
            self.use_cache = False
            self.composed_all_augmentations = ComposeAug(augmentations)

    @abstractmethod
    def load_input(self, path):
        """Load the raw input dict for *path*; implemented by subclasses."""
        pass

    @property
    @abstractmethod
    def cached_extension(self):
        """File extension used for cached files; implemented by subclasses."""
        pass

    def configure_path(self, path, sample=None):
        """Hook for subclasses to rewrite *path* before loading; identity by default."""
        return path

    def get_image(self, path, sample=None):
        """
        Returns a transformed image by its absolute path.
        If cache is used - transformed image will be loaded if available,
        and saved to cache if not.
        """
        input_dict = {}
        input_path = self.configure_path(path, sample)

        # Padding sentinel: return the pad input untouched.
        if input_path == self.pad_token:
            return self.load_input(input_path)

        if not self.use_cache:
            input_dict = self.load_input(input_path)
            # hidden loaders typically do not use augmentation
            if self.apply_augmentations:
                input_dict = self.composed_all_augmentations(input_dict, sample)
            return input_dict

        if self.args.use_annotations:
            # NOTE(review): input_dict is empty here, so reading
            # input_dict["annotations"] looks like it would raise KeyError
            # whenever use_annotations is set — verify against callers.
            input_dict["mask"] = get_scaled_annotation_mask(
                input_dict["annotations"], self.args
            )

        # Try the longest cached prefix first (split_augmentations is ordered
        # latest-split-first); on a hit, apply only the remaining augmentations.
        for key, post_augmentations in self.split_augmentations:
            base_key = (
                DEFAULT_CACHE_DIR + key
                if key != DEFAULT_CACHE_DIR
                else DEFAULT_CACHE_DIR
            )
            if self.cache.exists(input_path, base_key):
                try:
                    input_dict["input"] = self.cache.get(input_path, base_key)
                    if self.apply_augmentations:
                        input_dict = apply_augmentations_and_cache(
                            input_dict,
                            sample,
                            input_path,
                            post_augmentations,
                            self.cache,
                            base_key=base_key,
                        )
                    return input_dict
                except Exception as e:
                    print(e)
                    warnings.warn(CORUPTED_FILE_ERR.format(sys.exc_info()[0]))
                    # BUGFIX: the entry was read under ``base_key`` (not
                    # ``key``), so it must be removed under ``base_key`` too —
                    # otherwise the corrupted file survives and fails on every
                    # subsequent access.
                    self.cache.rem(input_path, base_key)

        # Cache miss (or every cached copy was corrupted): load from scratch
        # and apply the full pipeline, repopulating the cache as we go. The
        # last split entry always pairs DEFAULT_CACHE_DIR with all
        # augmentations.
        all_augmentations = self.split_augmentations[-1][1]
        input_dict = self.load_input(input_path)
        if self.apply_augmentations:
            input_dict = apply_augmentations_and_cache(
                input_dict,
                sample,
                input_path,
                all_augmentations,
                self.cache,
                base_key=DEFAULT_CACHE_DIR,
            )

        return input_dict