# sybil/loaders/abstract_loader.py
import os
import os.path
import sys
import warnings
import hashlib
from abc import ABCMeta, abstractmethod

import numpy as np

from sybil.augmentations import ComposeAug
from sybil.datasets.utils import get_scaled_annotation_mask, IMG_PAD_TOKEN

CACHED_FILES_EXT = ".png"
DEFAULT_CACHE_DIR = "default/"

CORRUPTED_FILE_ERR = (
    "WARNING! Error processing file from cache - removed file from cache. Error: {}"
)


def md5(key):
    """Return the MD5 hex digest of ``key``; used to build cache filenames."""
    return hashlib.md5(key.encode()).hexdigest()
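
# For instance, md5("/data/series_1/slice_001.dcm") gives a 32-character hex
# string (the path is illustrative); the same path always maps to the same
# cache filename.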

def split_augmentations_by_cache(augmentations):
    """
    Given a list of augmentations, return a list of tuples. Each tuple
    contains a caching key for the augmentations up to the splitting point
    and a list of the augmentations that should be applied afterwards.

    split_augmentations holds one entry per possible split by cachable
    augmentations, ordered from the longest cachable prefix down to the
    empty one; the last tuple holds all augmentations.

    Note - a split is only made at indexes where every augmentation up to
    that point is cachable.
    """
    # list of (cache key, post augmentations)
    split_augmentations = [(DEFAULT_CACHE_DIR, augmentations)]
    all_prev_cachable = True
    key = ""
    for ind, trans in enumerate(augmentations):
        # check trans.cachable() first separately to save run time
        if not all_prev_cachable or not trans.cachable():
            all_prev_cachable = False
        else:
            key += trans.caching_keys()
            # slicing past the last index yields [], so no bounds check is needed
            post_augmentations = augmentations[ind + 1 :]
            split_augmentations.append((key, post_augmentations))

    return list(reversed(split_augmentations))
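
# For example, with cachable transforms A and B (caching_keys() "a" and "b")
# followed by a non-cachable transform C (names are illustrative):
#   split_augmentations_by_cache([A, B, C])
#   == [("ab", [C]), ("a", [B, C]), ("default/", [A, B, C])]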

def apply_augmentations_and_cache(
    loaded_input, sample, img_path, augmentations, cache, base_key=""
):
    """
    Apply the augmentations to loaded_input one by one (mirroring what the
    composed augmentation does) and cache each intermediate output produced
    by the leading run of cachable transformers; caching stops at the first
    non-cachable one.
    """
    all_prev_cachable = True
    key = base_key
    for trans in augmentations:
        loaded_input = trans(loaded_input, sample)
        if not all_prev_cachable or not trans.cachable():
            all_prev_cachable = False
        else:
            key += trans.caching_keys()
            cache.add(img_path, key, loaded_input["input"])

    return loaded_input
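
# E.g. resuming from a cache hit at key "default/a" with remaining transforms
# [B, C] (B cachable with caching_keys() "b", C not): B's output is cached
# under "default/ab", then C is applied without caching (illustrative names).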

class cache:
    def __init__(self, path, extension=CACHED_FILES_EXT):
        if not os.path.exists(path):
            os.makedirs(path)

        self.cache_dir = path
        self.files_extension = extension
        # arrays are written with np.save, so file names always end in .npy
        if extension != ".npy":
            self.files_extension += ".npy"

    def _file_dir(self, attr_key, par_dir):
        return os.path.join(self.cache_dir, attr_key, par_dir)

    def _file_path(self, attr_key, par_dir, hashed_key):
        return os.path.join(
            self.cache_dir, attr_key, par_dir, hashed_key + self.files_extension
        )

    def _parent_dir(self, path):
        return os.path.basename(os.path.dirname(path))

    def exists(self, image_path, attr_key):
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        return os.path.isfile(self._file_path(attr_key, par_dir, hashed_key))

    def get(self, image_path, attr_key):
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        return np.load(self._file_path(attr_key, par_dir, hashed_key))

    def add(self, image_path, attr_key, image):
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        file_dir = self._file_dir(attr_key, par_dir)
        if not os.path.exists(file_dir):
            os.makedirs(file_dir)
        np.save(self._file_path(attr_key, par_dir, hashed_key), image)

    def rem(self, image_path, attr_key):
        hashed_key = md5(image_path)
        par_dir = self._parent_dir(image_path)
        try:
            os.remove(self._file_path(attr_key, par_dir, hashed_key))
        # don't raise an error if the file does not exist
        except OSError:
            pass
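
# Entries live at <cache_dir>/<attr_key>/<parent dir of image>/<md5(path)>.
# A minimal usage sketch (paths are illustrative):
#   c = cache("/tmp/sybil_cache")
#   c.add("/data/series_1/slice_001.png", "default/", np.zeros((256, 256)))
#   assert c.exists("/data/series_1/slice_001.png", "default/")
#   arr = c.get("/data/series_1/slice_001.png", "default/")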

class abstract_loader:
    __metaclass__ = ABCMeta

    def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
        self.pad_token = IMG_PAD_TOKEN
        self.augmentations = augmentations
        self.args = args
        self.apply_augmentations = apply_augmentations
        if cache_path is not None:
            self.use_cache = True
            self.cache = cache(cache_path, self.cached_extension)
            self.split_augmentations = split_augmentations_by_cache(augmentations)
        else:
            self.use_cache = False
            self.composed_all_augmentations = ComposeAug(augmentations)

    @abstractmethod
    def load_input(self, path):
        pass

    @property
    @abstractmethod
    def cached_extension(self):
        pass

    def configure_path(self, path, sample=None):
        return path

    def get_image(self, path, sample=None):
        """
        Return a transformed image given its absolute path.
        If the cache is used, the transformed image is loaded from the cache
        when available, and saved to it when not.
        """
        input_dict = {}
        input_path = self.configure_path(path, sample)

        if input_path == self.pad_token:
            return self.load_input(input_path)

        if not self.use_cache:
            input_dict = self.load_input(input_path)
            # hidden loaders typically do not use augmentation
            if self.apply_augmentations:
                input_dict = self.composed_all_augmentations(input_dict, sample)
            return input_dict

        if self.args.use_annotations:
            # input_dict is still empty here, so the annotations are taken
            # from the sample metadata (assumed present whenever
            # use_annotations is set)
            input_dict["mask"] = get_scaled_annotation_mask(
                sample["annotations"], self.args
            )

        # try the splits from the longest cached prefix down to the raw input
        for key, post_augmentations in self.split_augmentations:
            base_key = (
                DEFAULT_CACHE_DIR + key
                if key != DEFAULT_CACHE_DIR
                else DEFAULT_CACHE_DIR
            )
            if self.cache.exists(input_path, base_key):
                try:
                    input_dict["input"] = self.cache.get(input_path, base_key)
                    if self.apply_augmentations:
                        input_dict = apply_augmentations_and_cache(
                            input_dict,
                            sample,
                            input_path,
                            post_augmentations,
                            self.cache,
                            base_key=base_key,
                        )
                    return input_dict
                except Exception as e:
                    print(e)
                    warnings.warn(CORRUPTED_FILE_ERR.format(sys.exc_info()[0]))
                    # remove the corrupted entry under the same key it was
                    # read from, then fall through to a shorter split
                    self.cache.rem(input_path, base_key)

        # cache miss on every split: load from disk and run the full
        # pipeline, caching each cachable prefix along the way
        base_key, all_augmentations = self.split_augmentations[-1]
        input_dict = self.load_input(input_path)
        if self.apply_augmentations:
            input_dict = apply_augmentations_and_cache(
                input_dict,
                sample,
                input_path,
                all_augmentations,
                self.cache,
                base_key=base_key,
            )

        return input_dict
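
# A concrete loader only needs load_input and cached_extension. A minimal
# sketch (hypothetical subclass, not part of this module; assumes PIL):
#
#   from PIL import Image
#
#   class png_loader(abstract_loader):
#       def load_input(self, path):
#           # wrap the raw pixels in the dict the pipeline expects
#           return {"input": np.asarray(Image.open(path))}
#
#       @property
#       def cached_extension(self):
#           return ".png"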