
--- a
+++ b/BioSeqNet/resnest/transforms.py
@@ -0,0 +1,411 @@
+# code adapted from:
+# https://github.com/kakaobrain/fast-autoaugment
+# https://github.com/rpmcruz/autoaugment
+import math
+import random
+
+import numpy as np
+from collections import defaultdict
+import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
+from PIL import Image
+
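+# When random_mirror is enabled, the geometric ops below (shear, translate,
+# rotate) flip the sign of their magnitude with probability 0.5.
+# RESAMPLE_MODE sets the interpolation used by the affine transforms.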
+random_mirror = True
+
+RESAMPLE_MODE = Image.BICUBIC
+
+def ShearX(img, v):  # [-0.3, 0.3]
+    assert -0.3 <= v <= 0.3
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0),
+                         RESAMPLE_MODE)
+
+
+def ShearY(img, v):  # [-0.3, 0.3]
+    assert -0.3 <= v <= 0.3
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0),
+                         RESAMPLE_MODE)
+
+
+def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    assert -0.45 <= v <= 0.45
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[0]
+    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
+                         RESAMPLE_MODE)
+
+
+def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    assert -0.45 <= v <= 0.45
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[1]
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
+                         RESAMPLE_MODE)
+
+
+def TranslateXabs(img, v):  # [0, 100] in absolute pixels
+    assert 0 <= v
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
+                         RESAMPLE_MODE)
+
+
+def TranslateYabs(img, v):  # [0, 100] in absolute pixels
+    assert 0 <= v
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
+                         RESAMPLE_MODE)
+
+
+def Rotate(img, v):  # [-30, 30]
+    assert -30 <= v <= 30
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.rotate(v)
+
+
+def AutoContrast(img, _):
+    return PIL.ImageOps.autocontrast(img)
+
+
+def Invert(img, _):
+    return PIL.ImageOps.invert(img)
+
+
+def Equalize(img, _):
+    return PIL.ImageOps.equalize(img)
+
+
+def Flip(img, _):  # not from the paper
+    return PIL.ImageOps.mirror(img)
+
+
+def Solarize(img, v):  # [0, 256]
+    assert 0 <= v <= 256
+    return PIL.ImageOps.solarize(img, v)
+
+
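+# Adds a constant to every pixel (clipped to [0, 255]) before solarizing at
+# the given threshold; used by the RandAugment operation list below.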
+def SolarizeAdd(img, addition=0, threshold=128):
+    img_np = np.array(img).astype(int)  # widen dtype so the addition cannot overflow uint8
+    img_np = img_np + addition
+    img_np = np.clip(img_np, 0, 255)
+    img_np = img_np.astype(np.uint8)
+    img = Image.fromarray(img_np)
+    return PIL.ImageOps.solarize(img, threshold)
+
+
+def Posterize(img, v):  # [4, 8]
+    #assert 4 <= v <= 8
+    v = int(v)
+    return PIL.ImageOps.posterize(img, v)
+
+def Contrast(img, v):  # [0.1,1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Contrast(img).enhance(v)
+
+
+def Color(img, v):  # [0.1,1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Color(img).enhance(v)
+
+
+def Brightness(img, v):  # [0.1,1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Brightness(img).enhance(v)
+
+
+def Sharpness(img, v):  # [0.1,1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Sharpness(img).enhance(v)
+
+
+def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
+    # assert 0 <= v <= 20
+    if v < 0:
+        return img
+    w, h = img.size
+    x0 = np.random.uniform(w)
+    y0 = np.random.uniform(h)
+
+    x0 = int(max(0, x0 - v / 2.))
+    y0 = int(max(0, y0 - v / 2.))
+    x1 = min(w, x0 + v)
+    y1 = min(h, y0 + v)
+
+    xy = (x0, y0, x1, y1)
+    color = (125, 123, 114)
+    # color = (0, 0, 0)
+    img = img.copy()
+    PIL.ImageDraw.Draw(img).rectangle(xy, color)
+    return img
+
+
+def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
+    assert 0.0 <= v <= 0.2
+    if v <= 0.:
+        return img
+
+    v = v * img.size[0]
+    return CutoutAbs(img, v)
+
+
+
+def TranslateYAbs(img, v):  # [0, 10] in absolute pixels
+    assert 0 <= v <= 10
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v),
+                         resample=RESAMPLE_MODE)
+
+
+def TranslateXAbs(img, v):  # [0, 10] in absolute pixels
+    assert 0 <= v <= 10
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0),
+                         resample=RESAMPLE_MODE)
+
+
+def Posterize2(img, v):  # [0, 4]
+    assert 0 <= v <= 4
+    v = int(v)
+    return PIL.ImageOps.posterize(img, v)
+
+
+
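+# SamplePairing returns a closure that blends the input image with a randomly
+# chosen image from `imgs` using blend factor v; it is kept here for
+# completeness but commented out in augment_list below.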
+def SamplePairing(imgs):  # [0, 0.4]
+    def f(img1, v):
+        i = np.random.choice(len(imgs))
+        img2 = Image.fromarray(imgs[i])
+        return Image.blend(img1, img2, v)
+
+    return f
+
+
+def augment_list(for_autoaug=True):  # operations and their magnitude ranges
+    l = [
+        (ShearX, -0.3, 0.3),  # 0
+        (ShearY, -0.3, 0.3),  # 1
+        (TranslateX, -0.45, 0.45),  # 2
+        (TranslateY, -0.45, 0.45),  # 3
+        (Rotate, -30, 30),  # 4
+        (AutoContrast, 0, 1),  # 5
+        (Invert, 0, 1),  # 6
+        (Equalize, 0, 1),  # 7
+        (Solarize, 0, 256),  # 8
+        (Posterize, 4, 8),  # 9
+        (Contrast, 0.1, 1.9),  # 10
+        (Color, 0.1, 1.9),  # 11
+        (Brightness, 0.1, 1.9),  # 12
+        (Sharpness, 0.1, 1.9),  # 13
+        (Cutout, 0, 0.2),  # 14
+        # (SamplePairing(imgs), 0, 0.4),  # 15
+    ]
+    if for_autoaug:
+        l += [
+            (CutoutAbs, 0, 20),  # compatible with auto-augment
+            (Posterize2, 0, 4),
+            (TranslateXAbs, 0, 10),
+            (TranslateYAbs, 0, 10),
+        ]
+    return l
+
+
+augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
+
+PARAMETER_MAX = 10
+
+
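+# Map an AutoAugment-style level in [0, PARAMETER_MAX] to a concrete magnitude
+# in [0, maxval].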
+def float_parameter(level, maxval):
+    return float(level) * maxval / PARAMETER_MAX
+
+
+def int_parameter(level, maxval):
+    return int(float_parameter(level, maxval))
+
+
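+# Decorator: converts AutoAugment-style policies, whose integer levels lie in
+# [0, PARAMETER_MAX], into Fast AutoAugment-style policies whose magnitudes are
+# normalized to [0, 1] relative to each operation's (low, high) range in
+# augment_dict. Operations without an entry in `mapper` (e.g. AutoContrast,
+# Invert, Equalize) pass their level through unchanged before normalization.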
+def autoaug2fastaa(f):
+    def autoaug():
+        mapper = defaultdict(lambda: lambda x: x)
+        mapper.update({
+            'ShearX': lambda x: float_parameter(x, 0.3),
+            'ShearY': lambda x: float_parameter(x, 0.3),
+            'TranslateX': lambda x: int_parameter(x, 10),
+            'TranslateY': lambda x: int_parameter(x, 10),
+            'Rotate': lambda x: int_parameter(x, 30),
+            'Solarize': lambda x: 256 - int_parameter(x, 256),
+            'Posterize2': lambda x: 4 - int_parameter(x, 4),
+            'Contrast': lambda x: float_parameter(x, 1.8) + .1,
+            'Color': lambda x: float_parameter(x, 1.8) + .1,
+            'Brightness': lambda x: float_parameter(x, 1.8) + .1,
+            'Sharpness': lambda x: float_parameter(x, 1.8) + .1,
+            'CutoutAbs': lambda x: int_parameter(x, 20)
+        })
+
+        def low_high(name, prev_value):
+            _, low, high = get_augment(name)
+            return float(prev_value - low) / (high - low)
+
+        policies = f()
+        new_policies = []
+        for policy in policies:
+            new_policies.append([(name, pr, low_high(name, mapper[name](level))) for name, pr, level in policy])
+        return new_policies
+
+    return autoaug
+
+
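+# The AutoAugment ImageNet policies. Each entry is a sub-policy: a list of
+# (operation name, probability, level) tuples. The decorator above rewrites the
+# levels into normalized magnitudes consumed by apply_augment().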
+@autoaug2fastaa
+def autoaug_imagenet_policies():
+    return [
+        [('Posterize2', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('Posterize2', 0.6, 7), ('Posterize2', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('Posterize2', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('Posterize2', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 0)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+
+
+def get_augment(name):
+    return augment_dict[name]
+
+
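+# Rescale a normalized level in [0, 1] back to the operation's [low, high]
+# range and apply it, e.g. apply_augment(img, 'Rotate', 0.75) rotates by
+# 0.75 * (30 - (-30)) + (-30) = 15 degrees.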
+def apply_augment(img, name, level):
+    augment_fn, low, high = get_augment(name)
+    return augment_fn(img.copy(), level * (high - low) + low)
+
+
+def rand_augment_list():  # 16 operations and their ranges
+    l = [
+        (AutoContrast, 0, 1),
+        (Equalize, 0, 1),
+        (Invert, 0, 1),
+        (Rotate, 0, 30),
+        (Posterize, 0, 4),
+        (Solarize, 0, 256),
+        (SolarizeAdd, 0, 110),
+        (Color, 0.1, 1.9),
+        (Contrast, 0.1, 1.9),
+        (Brightness, 0.1, 1.9),
+        (Sharpness, 0.1, 1.9),
+        (ShearX, 0., 0.3),
+        (ShearY, 0., 0.3),
+        (CutoutAbs, 0, 40),
+        (TranslateXabs, 0., 100),
+        (TranslateYabs, 0., 100),
+    ]
+
+    return l
+
+
+
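+# Inception-style random crop: sample an aspect ratio and area within the given
+# ranges (mirroring TensorFlow's sample_distorted_bounding_box, linked below)
+# and fall back to ECenterCrop if no valid crop is found within max_attempts.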
+class ERandomCrop:
+    # pylint: disable=misplaced-comparison-constant
+    def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3),
+                 area_range=(0.1, 1.0), max_attempts=10):
+        assert 0.0 < min_covered
+        assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
+        assert 0 < area_range[0] <= area_range[1]
+        assert 1 <= max_attempts
+
+        self.min_covered = min_covered
+        self.aspect_ratio_range = aspect_ratio_range
+        self.area_range = area_range
+        self.max_attempts = max_attempts
+        self._fallback = ECenterCrop(imgsize)
+
+    def __call__(self, img):
+        # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
+        original_width, original_height = img.size
+        min_area = self.area_range[0] * (original_width * original_height)
+        max_area = self.area_range[1] * (original_width * original_height)
+
+        for _ in range(self.max_attempts):
+            aspect_ratio = random.uniform(*self.aspect_ratio_range)
+            height = int(round(math.sqrt(min_area / aspect_ratio)))
+            max_height = int(round(math.sqrt(max_area / aspect_ratio)))
+
+            if max_height * aspect_ratio > original_width:
+                max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
+                max_height = int(max_height)
+                if max_height * aspect_ratio > original_width:
+                    max_height -= 1
+
+            if max_height > original_height:
+                max_height = original_height
+
+            if height >= max_height:
+                height = max_height
+
+            height = int(round(random.uniform(height, max_height)))
+            width = int(round(height * aspect_ratio))
+            area = width * height
+
+            if area < min_area or area > max_area:
+                continue
+            if width > original_width or height > original_height:
+                continue
+            if area < self.min_covered * (original_width * original_height):
+                continue
+            if width == original_width and height == original_height:
+                return self._fallback(img)
+
+            x = random.randint(0, original_width - width)
+            y = random.randint(0, original_height - height)
+            return img.crop((x, y, x + width, y + height))
+
+        return self._fallback(img)
+
+class ECenterCrop:
+    """Crop the given PIL Image and resize it to desired size.
+    Args:
+        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        output_size (sequence or int): (height, width) of the crop box. If int,
+            it is used for both directions
+    Returns:
+        PIL Image: Cropped image.
+    """
+    def __init__(self, imgsize):
+        self.imgsize = imgsize
+        import torchvision.transforms as pth_transforms
+        self.resize_method = pth_transforms.Resize((imgsize, imgsize), interpolation=RESAMPLE_MODE)
+
+    def __call__(self, img):
+        image_width, image_height = img.size
+        image_short = min(image_width, image_height)
+
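+        # Crop a centered square covering imgsize / (imgsize + 32) of the
+        # shorter side (e.g. 224 / 256 = 0.875 for imgsize=224), then resize
+        # it to (imgsize, imgsize).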
+        crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short
+
+        crop_height, crop_width = crop_size, crop_size
+        crop_top = int(round((image_height - crop_height) / 2.))
+        crop_left = int(round((image_width - crop_width) / 2.))
+        img = img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
+        return self.resize_method(img)
+
+
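+# Illustrative usage sketch (not part of this module's API): pick a random
+# sub-policy and apply each operation with its probability. Assumes `img` is a
+# PIL Image.
+#
+#   policy = random.choice(autoaug_imagenet_policies())
+#   for name, pr, level in policy:
+#       if random.random() < pr:
+#           img = apply_augment(img, name, level)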