# BioSeqNet/resnest/transforms.py
# code adapted from:
# https://github.com/kakaobrain/fast-autoaugment
# https://github.com/rpmcruz/autoaugment
import math
import random

import numpy as np
from collections import defaultdict
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
from PIL import Image

random_mirror = True

RESAMPLE_MODE = Image.BICUBIC


def ShearX(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0),
                         RESAMPLE_MODE)


def ShearY(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0),
                         RESAMPLE_MODE)


def TranslateX(img, v):  # [-150, 150] px => fraction of width: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
                         RESAMPLE_MODE)


def TranslateY(img, v):  # [-150, 150] px => fraction of height: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
                         RESAMPLE_MODE)

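# Usage note (illustrative only, not part of the original code): each op above takes
# a PIL image and a magnitude `v` drawn from the range in its comment, and returns a
# new PIL image. A hypothetical direct call might look like:
#
#     img = Image.open("sample.jpg").convert("RGB")   # "sample.jpg" is a placeholder path
#     sheared = ShearX(img, 0.2)        # horizontal shear; sign may flip when random_mirror is True
#     shifted = TranslateX(img, 0.25)   # shift by 25% of the image width
#
# In practice these ops are usually invoked through `apply_augment` further below,
# which maps a normalized level in [0, 1] onto each op's (low, high) range.
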
def TranslateXabs(img, v):  # [0, 150] in absolute pixels
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
                         RESAMPLE_MODE)


def TranslateYabs(img, v):  # [0, 150] in absolute pixels
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
                         RESAMPLE_MODE)

def Rotate(img, v):  # [-30, 30]
    assert -30 <= v <= 30
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    return PIL.ImageOps.equalize(img)


def Flip(img, _):  # not from the paper
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):  # [0, 256]
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)

def SolarizeAdd(img, addition=0, threshold=128):
    # shift pixel values by `addition`, clip to [0, 255], then solarize
    img_np = np.array(img).astype(int)
    img_np = img_np + addition
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)

def Posterize(img, v):  # [4, 8]
    # assert 4 <= v <= 8
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Sharpness(img).enhance(v)

def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
    assert 0.0 <= v <= 0.2
    if v <= 0.:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)

def TranslateYAbs(img, v):  # [0, 10] in absolute pixels
    assert 0 <= v <= 10
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
                         resample=RESAMPLE_MODE)


def TranslateXAbs(img, v):  # [0, 10] in absolute pixels
    assert 0 <= v <= 10
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
                         resample=RESAMPLE_MODE)


def Posterize2(img, v):  # [0, 4]
    assert 0 <= v <= 4
    v = int(v)
    return PIL.ImageOps.posterize(img, v)

def SamplePairing(imgs):  # [0, 0.4]
    def f(img1, v):
        i = np.random.choice(len(imgs))
        img2 = Image.fromarray(imgs[i])
        return Image.blend(img1, img2, v)

    return f

def augment_list(for_autoaug=True):  # 16 operations and their ranges
    l = [
        (ShearX, -0.3, 0.3),  # 0
        (ShearY, -0.3, 0.3),  # 1
        (TranslateX, -0.45, 0.45),  # 2
        (TranslateY, -0.45, 0.45),  # 3
        (Rotate, -30, 30),  # 4
        (AutoContrast, 0, 1),  # 5
        (Invert, 0, 1),  # 6
        (Equalize, 0, 1),  # 7
        (Solarize, 0, 256),  # 8
        (Posterize, 4, 8),  # 9
        (Contrast, 0.1, 1.9),  # 10
        (Color, 0.1, 1.9),  # 11
        (Brightness, 0.1, 1.9),  # 12
        (Sharpness, 0.1, 1.9),  # 13
        (Cutout, 0, 0.2),  # 14
        # (SamplePairing(imgs), 0, 0.4),  # 15
    ]
    if for_autoaug:
        l += [
            (CutoutAbs, 0, 20),  # compatible with auto-augment
            (Posterize2, 0, 4),
            (TranslateXAbs, 0, 10),
            (TranslateYAbs, 0, 10),
        ]
    return l

augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}

PARAMETER_MAX = 10


def float_parameter(level, maxval):
    # map an AutoAugment level in [0, PARAMETER_MAX] to a float in [0, maxval]
    return float(level) * maxval / PARAMETER_MAX


def int_parameter(level, maxval):
    return int(float_parameter(level, maxval))


def autoaug2fastaa(f):
    # Decorator: converts AutoAugment-style policies (name, prob, integer level in
    # [0, 10]) returned by `f` into Fast AutoAugment format, where the level is the
    # absolute magnitude rescaled relative to the op's (low, high) range.
    def autoaug():
        mapper = defaultdict(lambda: lambda x: x)
        mapper.update({
            'ShearX': lambda x: float_parameter(x, 0.3),
            'ShearY': lambda x: float_parameter(x, 0.3),
            'TranslateX': lambda x: int_parameter(x, 10),
            'TranslateY': lambda x: int_parameter(x, 10),
            'Rotate': lambda x: int_parameter(x, 30),
            'Solarize': lambda x: 256 - int_parameter(x, 256),
            'Posterize2': lambda x: 4 - int_parameter(x, 4),
            'Contrast': lambda x: float_parameter(x, 1.8) + .1,
            'Color': lambda x: float_parameter(x, 1.8) + .1,
            'Brightness': lambda x: float_parameter(x, 1.8) + .1,
            'Sharpness': lambda x: float_parameter(x, 1.8) + .1,
            'CutoutAbs': lambda x: int_parameter(x, 20)
        })

        def low_high(name, prev_value):
            _, low, high = get_augment(name)
            return float(prev_value - low) / (high - low)

        policies = f()
        new_policies = []
        for policy in policies:
            new_policies.append([(name, pr, low_high(name, mapper[name](level)))
                                 for name, pr, level in policy])
        return new_policies

    return autoaug

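# Worked example of the conversion above (illustrative only): an AutoAugment entry
# ('Rotate', 0.8, 9) has level 9, so mapper['Rotate'](9) = int_parameter(9, 30) = 27
# degrees; with Rotate's range (-30, 30) from `augment_dict`, low_high gives
# (27 - (-30)) / 60 = 0.95, yielding the Fast AutoAugment tuple ('Rotate', 0.8, 0.95).
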
@autoaug2fastaa
def autoaug_imagenet_policies():
    return [
        [('Posterize2', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize2', 0.6, 7), ('Posterize2', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize2', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize2', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 0)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]

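# Example (illustrative sketch, not part of the original code): after decoration,
# `autoaug_imagenet_policies()` returns sub-policies of (name, prob, level) tuples
# with levels rescaled for `apply_augment` below. Applying one randomly chosen
# sub-policy to a PIL image could look like:
#
#     policies = autoaug_imagenet_policies()
#     sub_policy = random.choice(policies)
#     for name, prob, level in sub_policy:
#         if random.random() < prob:
#             img = apply_augment(img, name, level)
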
def get_augment(name):
    return augment_dict[name]


def apply_augment(img, name, level):
    # `level` is a fraction in [0, 1]; it is mapped linearly onto the op's (low, high) range
    augment_fn, low, high = get_augment(name)
    return augment_fn(img.copy(), level * (high - low) + low)

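# Example (illustrative sketch): applying a single named op at half magnitude.
# The image path is a placeholder.
#
#     img = Image.open("example.jpg").convert("RGB")
#     out = apply_augment(img, "Solarize", 0.5)   # threshold = 0.5 * (256 - 0) + 0 = 128
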
def rand_augment_list():  # 16 operations and their ranges
    l = [
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        (Invert, 0, 1),
        (Rotate, 0, 30),
        (Posterize, 0, 4),
        (Solarize, 0, 256),
        (SolarizeAdd, 0, 110),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Brightness, 0.1, 1.9),
        (Sharpness, 0.1, 1.9),
        (ShearX, 0., 0.3),
        (ShearY, 0., 0.3),
        (CutoutAbs, 0, 40),
        (TranslateXabs, 0., 100),
        (TranslateYabs, 0., 100),
    ]

    return l

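# Example (illustrative sketch; this module does not define a RandAugment policy
# class itself): a minimal RandAugment-style loop over `rand_augment_list`, applying
# `n` ops at a shared magnitude `m` out of 30. The helper name and the defaults are
# assumptions chosen for the example, not constants from this file.
#
#     def rand_augment(img, n=2, m=9):
#         for op, low, high in random.choices(rand_augment_list(), k=n):
#             v = low + (high - low) * m / 30.0
#             img = op(img, v)
#         return img
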
class ERandomCrop:
    # pylint: disable=misplaced-comparison-constant
    def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3),
                 area_range=(0.1, 1.0), max_attempts=10):
        assert 0.0 < min_covered
        assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
        assert 0 < area_range[0] <= area_range[1]
        assert 1 <= max_attempts

        self.min_covered = min_covered
        self.aspect_ratio_range = aspect_ratio_range
        self.area_range = area_range
        self.max_attempts = max_attempts
        self._fallback = ECenterCrop(imgsize)

    def __call__(self, img):
        # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
        original_width, original_height = img.size
        min_area = self.area_range[0] * (original_width * original_height)
        max_area = self.area_range[1] * (original_width * original_height)

        for _ in range(self.max_attempts):
            aspect_ratio = random.uniform(*self.aspect_ratio_range)
            height = int(round(math.sqrt(min_area / aspect_ratio)))
            max_height = int(round(math.sqrt(max_area / aspect_ratio)))

            if max_height * aspect_ratio > original_width:
                max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
                max_height = int(max_height)
                if max_height * aspect_ratio > original_width:
                    max_height -= 1

            if max_height > original_height:
                max_height = original_height

            if height >= max_height:
                height = max_height

            height = int(round(random.uniform(height, max_height)))
            width = int(round(height * aspect_ratio))
            area = width * height

            if area < min_area or area > max_area:
                continue
            if width > original_width or height > original_height:
                continue
            if area < self.min_covered * (original_width * original_height):
                continue
            if width == original_width and height == original_height:
                return self._fallback(img)

            x = random.randint(0, original_width - width)
            y = random.randint(0, original_height - height)
            return img.crop((x, y, x + width, y + height))

        return self._fallback(img)

class ECenterCrop:
    """Center-crop the given PIL Image and resize it to the desired size.

    The crop box side is imgsize / (imgsize + 32) of the shorter image side,
    and the crop is then resized to (imgsize, imgsize).

    Args:
        imgsize (int): target output size (height and width) in pixels.

    Returns:
        PIL Image: cropped and resized image.
    """
    def __init__(self, imgsize):
        self.imgsize = imgsize
        import torchvision.transforms as pth_transforms
        self.resize_method = pth_transforms.Resize((imgsize, imgsize), interpolation=RESAMPLE_MODE)

    def __call__(self, img):
        image_width, image_height = img.size
        image_short = min(image_width, image_height)

        # crop a central square whose side is imgsize / (imgsize + 32) of the shorter side
        crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short

        crop_height, crop_width = crop_size, crop_size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        img = img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
        return self.resize_method(img)
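

# Example (illustrative sketch, not part of the original module): the crop classes
# above are drop-in callables for torchvision pipelines. The sizes below are
# placeholders, not constants from this file.
#
#     import torchvision.transforms as T
#
#     train_tf = T.Compose([
#         ERandomCrop(224),          # random crop; output size varies, so resize afterwards
#         T.Resize((224, 224)),
#         T.RandomHorizontalFlip(),
#         T.ToTensor(),
#     ])
#     val_tf = T.Compose([
#         ECenterCrop(224),          # already returns a 224x224 image
#         T.ToTensor(),
#     ])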