Switch to side-by-side view

--- a
+++ b/lavis/processors/randaugment.py
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+    return img
+
+
+def autocontrast_func(img, cutoff=0):
+    """
+    same output as PIL.ImageOps.autocontrast
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        n = ch.size
+        cut = cutoff * n // 100
+        if cut == 0:
+            high, low = ch.max(), ch.min()
+        else:
+            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+            low = np.argwhere(np.cumsum(hist) > cut)
+            low = 0 if low.shape[0] == 0 else low[0]
+            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+        if high <= low:
+            table = np.arange(n_bins)
+        else:
+            scale = (n_bins - 1) / (high - low)
+            offset = -low * scale
+            table = np.arange(n_bins) * scale + offset
+            table[table < 0] = 0
+            table[table > n_bins - 1] = n_bins - 1
+        table = table.clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def equalize_func(img):
+    """
+    same output as PIL.ImageOps.equalize
+    PIL's implementation is different from cv2.equalize
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+        non_zero_hist = hist[hist != 0].reshape(-1)
+        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+        if step == 0:
+            return ch
+        n = np.empty_like(hist)
+        n[0] = step // 2
+        n[1:] = hist[:-1]
+        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+    """
+    like PIL, rotate by degree, not radians
+    """
+    H, W = img.shape[0], img.shape[1]
+    center = W / 2, H / 2
+    M = cv2.getRotationMatrix2D(center, degree, 1)
+    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+    return out
+
+
+def solarize_func(img, thresh=128):
+    """
+    same output as PIL.ImageOps.posterize
+    """
+    table = np.array([el if el < thresh else 255 - el for el in range(256)])
+    table = table.clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def color_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Color
+    """
+    ## implementation according to PIL definition, quite slow
+    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+    #  out = blend(degenerate, img, factor)
+    #  M = (
+    #      np.eye(3) * factor
+    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+    #  )[np.newaxis, np.newaxis, :]
+    M = np.float32(
+        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+    ) * factor + np.float32([[0.114], [0.587], [0.299]])
+    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+    return out
+
+
+def contrast_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Contrast
+    """
+    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+    table = (
+        np.array([(el - mean) * factor + mean for el in range(256)])
+        .clip(0, 255)
+        .astype(np.uint8)
+    )
+    out = table[img]
+    return out
+
+
+def brightness_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Contrast
+    """
+    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def sharpness_func(img, factor):
+    """
+    The differences the this result and PIL are all on the 4 boundaries, the center
+    areas are same
+    """
+    kernel = np.ones((3, 3), dtype=np.float32)
+    kernel[1][1] = 5
+    kernel /= 13
+    degenerate = cv2.filter2D(img, -1, kernel)
+    if factor == 0.0:
+        out = degenerate
+    elif factor == 1.0:
+        out = img
+    else:
+        out = img.astype(np.float32)
+        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+        out = out.astype(np.uint8)
+    return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, factor, 0], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, -offset], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [0, 1, -offset]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def posterize_func(img, bits):
+    """
+    same output as PIL.ImageOps.posterize
+    """
+    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+    return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [factor, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+    replace = np.array(replace, dtype=np.uint8)
+    H, W = img.shape[0], img.shape[1]
+    rh, rw = np.random.random(2)
+    pad_size = pad_size // 2
+    ch, cw = int(rh * H), int(rw * W)
+    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+    out = img.copy()
+    out[x1:x2, y1:y2, :] = replace
+    return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+    return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 0.3
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * float(translate_const)
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * cutout_const)
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 256)
+        return (level,)
+
+    return level_to_args
+
+
+def none_level_to_args(level):
+    return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 4)
+        return (level,)
+
+    return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 30
+        if np.random.random() < 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+func_dict = {
+    "Identity": identity_func,
+    "AutoContrast": autocontrast_func,
+    "Equalize": equalize_func,
+    "Rotate": rotate_func,
+    "Solarize": solarize_func,
+    "Color": color_func,
+    "Contrast": contrast_func,
+    "Brightness": brightness_func,
+    "Sharpness": sharpness_func,
+    "ShearX": shear_x_func,
+    "TranslateX": translate_x_func,
+    "TranslateY": translate_y_func,
+    "Posterize": posterize_func,
+    "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+    "Identity": none_level_to_args,
+    "AutoContrast": none_level_to_args,
+    "Equalize": none_level_to_args,
+    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+    "Solarize": solarize_level_to_args(MAX_LEVEL),
+    "Color": enhance_level_to_args(MAX_LEVEL),
+    "Contrast": enhance_level_to_args(MAX_LEVEL),
+    "Brightness": enhance_level_to_args(MAX_LEVEL),
+    "Sharpness": enhance_level_to_args(MAX_LEVEL),
+    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "Posterize": posterize_level_to_args(MAX_LEVEL),
+    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+        self.N = N
+        self.M = M
+        self.isPIL = isPIL
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N)
+        return [(op, 0.5, self.M) for op in sampled_ops]
+
+    def __call__(self, img):
+        if self.isPIL:
+            img = np.array(img)
+        ops = self.get_random_ops()
+        for name, prob, level in ops:
+            if np.random.random() > prob:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return img
+
+
+class VideoRandomAugment(object):
+    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+        self.N = N
+        self.M = M
+        self.p = p
+        self.tensor_in_tensor_out = tensor_in_tensor_out
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+        return [(op, self.M) for op in sampled_ops]
+
+    def __call__(self, frames):
+        assert (
+            frames.shape[-1] == 3
+        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
+
+        if self.tensor_in_tensor_out:
+            frames = frames.numpy().astype(np.uint8)
+
+        num_frames = frames.shape[0]
+
+        ops = num_frames * [self.get_random_ops()]
+        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+        frames = torch.stack(
+            list(map(self._aug, frames, ops, apply_or_not)), dim=0
+        ).float()
+
+        return frames
+
+    def _aug(self, img, ops, apply_or_not):
+        for i, (name, level) in enumerate(ops):
+            if not apply_or_not[i]:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+    a = RandomAugment()
+    img = np.random.randn(32, 32, 3)
+    a(img)