a b/minigpt4/processors/randaugment.py
1
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import cv2
9
import numpy as np
10
11
import torch
12
13
14
## aug functions
15
def identity_func(img):
    """No-op augmentation: hand back ``img`` exactly as received."""
    result = img
    return result
17
18
19
def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast

    Each channel is stretched linearly so that its darkest value maps to 0 and
    its brightest to 255. When ``cutoff`` > 0, ``cutoff`` percent of the pixels
    are ignored at each end of the histogram before choosing those extremes.
    Channels whose extremes coincide are returned unchanged.

    Args:
        img: integer image, either (H, W) or (H, W, C); channels are the last axis.
        cutoff: percentage of pixels to clip from each histogram end.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # Convert to Python ints: ``ch.min()`` is an np.uint8, and negating it
            # (``-low`` below) wraps around (e.g. -10 -> 246). That produced a huge
            # positive offset and an almost entirely white channel.
            high, low = int(ch.max()), int(ch.min())
        else:
            hist = np.bincount(ch.ravel(), minlength=n_bins)[:n_bins]
            # First bin where the cumulative count (from the dark end) exceeds `cut`.
            low_idx = np.flatnonzero(np.cumsum(hist) > cut)
            low = int(low_idx[0]) if low_idx.size else 0
            # Same search from the bright end.
            high_idx = np.flatnonzero(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 - int(high_idx[0]) if high_idx.size else n_bins - 1
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            table = np.arange(n_bins) * scale - low * scale
        table = table.clip(0, n_bins - 1).astype(np.uint8)
        return table[ch]

    if img.ndim == 2:
        return tune_channel(img)
    return np.stack(
        [tune_channel(img[..., c]) for c in range(img.shape[-1])], axis=-1
    )
50
51
52
def equalize_func(img):
    """
    same output as PIL.ImageOps.equalize
    (PIL's algorithm differs from cv2.equalizeHist, hence this re-implementation)

    Builds a per-channel lookup table from the channel histogram, following the
    PIL recipe; a channel whose step size works out to 0 is returned unchanged.
    """
    n_bins = 256

    def _equalize_channel(channel):
        counts = cv2.calcHist([channel], [0], None, [n_bins], [0, n_bins])
        used_bins = counts[counts != 0].reshape(-1)
        # PIL ignores the last populated bin when computing the step size.
        step = np.sum(used_bins[:-1]) // (n_bins - 1)
        if step == 0:
            return channel
        # Shifted histogram: slot 0 holds the rounding term, slot i holds count[i-1].
        shifted = np.empty_like(counts)
        shifted[0] = step // 2
        shifted[1:] = counts[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[channel]

    equalized = [_equalize_channel(channel) for channel in cv2.split(img)]
    return cv2.merge(equalized)
74
75
76
def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    Rotate ``img`` about its center by ``degree`` degrees (degrees, like PIL,
    not radians), keeping the original size and filling uncovered pixels
    with ``fill``.
    """
    rows, cols = img.shape[0], img.shape[1]
    pivot = cols / 2, rows / 2
    rotation = cv2.getRotationMatrix2D(pivot, degree, 1)
    return cv2.warpAffine(img, rotation, (cols, rows), borderValue=fill)
85
86
87
def solarize_func(img, thresh=128):
    """
    same output as PIL.ImageOps.solarize

    Every pixel value ``v >= thresh`` is inverted to ``255 - v``; values below
    the threshold are left untouched (``thresh=256`` is therefore a no-op).

    Args:
        img: uint8 image (any shape); used as indices into a 256-entry table.
        thresh: inversion threshold.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    values = np.arange(256)
    table = np.where(values < thresh, values, 255 - values)
    table = table.clip(0, 255).astype(np.uint8)
    return table[img]
95
96
97
def color_func(img, factor):
    """
    same output as PIL.ImageEnhance.Color

    Blends every pixel with its own gray value: ``factor=0`` yields a gray
    image, ``factor`` near 1 approximately the original, larger values
    over-saturate. The gray weights (0.114, 0.587, 0.299) are listed in
    B, G, R order, i.e. they assume a BGR channel layout.
    """
    # blend(gray, img, factor) is folded into one 3x3 mixing matrix applied with
    # a single matmul: mix[i, j] = factor * (delta_ij - w_i) + w_i, where the row
    # index i is the input channel and w_i its gray weight. This is much faster
    # than converting to gray and blending explicitly.
    deviation = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    gray_weights = np.float32([[0.114], [0.587], [0.299]])
    mix = deviation * factor + gray_weights
    mixed = np.matmul(img, mix)
    return mixed.clip(0, 255).astype(np.uint8)
113
114
115
def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast

    Pixel values are pushed away from (factor > 1) or toward (factor < 1) the
    image's weighted mean intensity, using the weights (0.114, 0.587, 0.299)
    over the last axis, then clipped to [0, 255].
    """
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * np.array([0.114, 0.587, 0.299]))
    levels = np.arange(256)
    lut = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return lut[img]
127
128
129
def brightness_func(img, factor):
    """
    same output as PIL.ImageEnhance.Brightness

    Every pixel value is multiplied by ``factor`` (0 -> black, 1 -> original,
    > 1 -> brighter), then clipped to [0, 255] and truncated to uint8.

    Args:
        img: uint8 image (any shape); used as indices into a 256-entry table.
        factor: brightness multiplier.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out
136
137
138
def sharpness_func(img, factor):
    """
    same output as PIL.ImageEnhance.Sharpness, except on the image border

    Blends ``img`` with a smoothed copy (PIL's SMOOTH kernel: ones with a
    center weight of 5, normalized by 13): factor 0 -> smoothed, 1 -> original,
    > 1 -> sharpened. The differences from PIL are all on the 4 boundaries;
    the center areas are the same.

    Args:
        img: uint8 image indexed as (H, W, C) below.
        factor: blend factor (RandAugment passes values in [0.1, 1.9]).

    Returns:
        uint8 array with the same shape as ``img``.
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        # Only the interior is blended; the 1-pixel border keeps the original values.
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        # factor > 1 extrapolates past [0, 255]. Clip before casting: converting
        # out-of-range floats to uint8 wraps around (or is undefined), turning
        # bright edges dark and vice versa. PIL saturates these values instead.
        out = out.clip(0, 255).astype(np.uint8)
    return out
157
158
159
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """
    Shear ``img`` horizontally using the affine matrix
    [[1, factor, 0], [0, 1, 0]] (bilinear sampling), keeping the original size
    and filling uncovered pixels with ``fill``. Returns a uint8 array.
    """
    rows, cols = img.shape[0], img.shape[1]
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    sheared = cv2.warpAffine(
        img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return sheared.astype(np.uint8)
166
167
168
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Translate the content of ``img`` by ``-offset`` pixels along the x axis
    (bilinear sampling), keeping the original size and filling uncovered
    pixels with ``fill``. Returns a uint8 array.
    """
    rows, cols = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, -offset], [0, 1, 0]])
    moved = cv2.warpAffine(
        img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return moved.astype(np.uint8)
178
179
180
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Translate the content of ``img`` by ``-offset`` pixels along the y axis
    (bilinear sampling), keeping the original size and filling uncovered
    pixels with ``fill``. Returns a uint8 array.
    """
    rows, cols = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, 0], [0, 1, -offset]])
    moved = cv2.warpAffine(
        img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return moved.astype(np.uint8)
190
191
192
def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize

    Keeps only the ``bits`` most significant bits of every 8-bit value
    (``bits=8`` is a no-op, ``bits=0`` zeroes the image).

    Args:
        img: uint8 image.
        bits: number of high bits to keep, in [0, 8].

    Returns:
        uint8 array with the same shape as ``img``.
    """
    # Reduce the mask to 8 bits *before* building the uint8 scalar:
    # np.uint8(255 << 4) == np.uint8(4080) is out of range and raises
    # OverflowError on NumPy >= 2 (older versions warned and wrapped).
    mask = (0xFF << (8 - bits)) & 0xFF
    return np.bitwise_and(img, np.uint8(mask))
198
199
200
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """
    Shear ``img`` vertically using the affine matrix
    [[1, 0, 0], [factor, 1, 0]] (bilinear sampling), keeping the original size
    and filling uncovered pixels with ``fill``. Returns a uint8 array.
    """
    rows, cols = img.shape[0], img.shape[1]
    shear = np.float32([[1, 0, 0], [factor, 1, 0]])
    sheared = cv2.warpAffine(
        img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return sheared.astype(np.uint8)
207
208
209
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """
    Return a copy of ``img`` in which a square of side about ``pad_size``,
    centered at a uniformly random pixel and clipped to the image bounds,
    is filled with ``replace``. The input array is not modified.
    """
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    frac_row, frac_col = np.random.random(2)
    half = pad_size // 2
    center_row = int(frac_row * height)
    center_col = int(frac_col * width)
    top = max(center_row - half, 0)
    bottom = min(center_row + half, height)
    left = max(center_col - half, 0)
    right = min(center_col + half, width)
    patched = img.copy()
    patched[top:bottom, left:right, :] = fill
    return patched
220
221
222
### level to args
223
def enhance_level_to_args(MAX_LEVEL):
    """
    Build a converter mapping a magnitude ``level`` to a one-tuple holding an
    enhancement factor; levels 0..MAX_LEVEL map linearly onto 0.1..1.9.
    """

    def level_to_args(level):
        factor = 0.1 + (level / MAX_LEVEL) * 1.8
        return (factor,)

    return level_to_args
228
229
230
def shear_level_to_args(MAX_LEVEL, replace_value):
    """
    Build a converter mapping a magnitude ``level`` to ``(shear, fill)``, where
    |shear| = level / MAX_LEVEL * 0.3. The sign is drawn at random on every
    call (negative when the uniform draw exceeds 0.5).
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        negate = np.random.random() > 0.5
        shear = -magnitude if negate else magnitude
        return (shear, replace_value)

    return level_to_args
238
239
240
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """
    Build a converter mapping a magnitude ``level`` to ``(offset, fill)``, where
    |offset| = level / MAX_LEVEL * translate_const pixels. The sign is drawn at
    random on every call (negative when the uniform draw exceeds 0.5).
    """

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        negate = np.random.random() > 0.5
        offset = -magnitude if negate else magnitude
        return (offset, replace_value)

    return level_to_args
248
249
250
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """
    Build a converter mapping a magnitude ``level`` to ``(pad_size, fill)``, with
    pad_size = int(level / MAX_LEVEL * cutout_const).
    """

    def level_to_args(level):
        pad_size = int((level / MAX_LEVEL) * cutout_const)
        return pad_size, replace_value

    return level_to_args
256
257
258
def solarize_level_to_args(MAX_LEVEL):
    """
    Build a converter mapping a magnitude ``level`` to a one-tuple holding the
    solarize threshold int(level / MAX_LEVEL * 256).
    """

    def level_to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)

    return level_to_args
264
265
266
def none_level_to_args(level):
    """Level converter for parameter-free ops: ``level`` is ignored and no extra args are produced."""
    del level  # intentionally unused
    return tuple()
268
269
270
def posterize_level_to_args(MAX_LEVEL):
    """
    Build a converter mapping a magnitude ``level`` to a one-tuple holding the
    number of bits to keep, int(level / MAX_LEVEL * 4).
    """

    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)

    return level_to_args
276
277
278
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """
    Build a converter mapping a magnitude ``level`` to ``(degrees, fill)``, where
    |degrees| = level / MAX_LEVEL * 30. The sign is drawn at random on every
    call (negative when the uniform draw is below 0.5).
    """

    def level_to_args(level):
        angle = (level / MAX_LEVEL) * 30
        flip = np.random.random() < 0.5
        if flip:
            angle = -angle
        return angle, replace_value

    return level_to_args
286
287
288
# Augmentation name -> function applying it. Each function takes the image
# followed by the extra positional args produced by ``arg_dict[name]``.
func_dict = dict(
    Identity=identity_func,
    AutoContrast=autocontrast_func,
    Equalize=equalize_func,
    Rotate=rotate_func,
    Solarize=solarize_func,
    Color=color_func,
    Contrast=contrast_func,
    Brightness=brightness_func,
    Sharpness=sharpness_func,
    ShearX=shear_x_func,
    TranslateX=translate_x_func,
    TranslateY=translate_y_func,
    Posterize=posterize_func,
    ShearY=shear_y_func,
)
304
305
# Magnitude settings shared by the level-to-argument converters below.
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)  # fill passed to the rotate/shear/translate ops

# Augmentation name -> callable turning a magnitude level into the extra
# positional arguments of ``func_dict[name]``. The key order matches func_dict
# and determines the default op list sampled by the augmenters.
arg_dict = dict(
    Identity=none_level_to_args,
    AutoContrast=none_level_to_args,
    Equalize=none_level_to_args,
    Rotate=rotate_level_to_args(MAX_LEVEL, replace_value),
    Solarize=solarize_level_to_args(MAX_LEVEL),
    Color=enhance_level_to_args(MAX_LEVEL),
    Contrast=enhance_level_to_args(MAX_LEVEL),
    Brightness=enhance_level_to_args(MAX_LEVEL),
    Sharpness=enhance_level_to_args(MAX_LEVEL),
    ShearX=shear_level_to_args(MAX_LEVEL, replace_value),
    TranslateX=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    TranslateY=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    Posterize=posterize_level_to_args(MAX_LEVEL),
    ShearY=shear_level_to_args(MAX_LEVEL, replace_value),
)
324
325
326
class RandomAugment(object):
    """
    RandAugment-style image augmentation.

    Each call samples ``N`` op names (with replacement) from ``augs`` and applies
    each sampled op with probability 0.5 at magnitude level ``M``.

    Args:
        N: number of ops sampled per call.
        M: magnitude level handed to the op's level-to-args converter.
        isPIL: if True, the input is converted with ``np.array`` before augmenting.
        augs: op names (keys of ``arg_dict`` / ``func_dict``); when empty or
            None, all keys of ``arg_dict`` are used.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # ``None`` replaces the former mutable ``[]`` default; the list is copied
        # so that later mutation of the caller's list cannot change the op set.
        if augs:
            self.augs = list(augs)
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Return ``N`` ``(name, probability, level)`` triples sampled with replacement."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """
        Augment one image. Returns the transformed array (a NumPy array when
        ``isPIL`` is set, since the input is converted first).
        """
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # Skip the op unless the uniform draw is <= prob.
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
350
351
352
class VideoRandomAugment(object):
    """
    RandAugment for video clips.

    One set of ``N`` distinct ops is sampled per clip, and each op's random
    arguments (e.g. the sign of a rotation, shear or translation) are drawn
    once, so every frame of the clip receives the identical transformation.

    Args:
        N: number of distinct ops sampled per clip (without replacement, so
            ``N`` must not exceed ``len(augs)``).
        M: magnitude level for every op.
        p: an op is skipped when its uniform draw is <= ``p`` (``p=0.0`` almost
            always applies every op).
        tensor_in_tensor_out: if True, ``frames`` is a torch tensor that is
            converted via ``.numpy().astype(np.uint8)`` first.
        augs: op names; when empty or None, all keys of ``arg_dict`` are used.
    """

    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=None):
        self.N = N
        self.M = M
        self.p = p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        # ``None`` replaces the former mutable ``[]`` default; copy the caller's list.
        if augs:
            self.augs = list(augs)
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Return ``N`` distinct ``(name, level)`` pairs."""
        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
        return [(op, self.M) for op in sampled_ops]

    def __call__(self, frames):
        """
        Augment every frame of ``frames`` (frames along the first axis, 3
        channels on the last) and return a float32 torch tensor.
        """
        assert (
            frames.shape[-1] == 3
        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."

        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)

        ops = self.get_random_ops()
        apply_or_not = np.random.random(size=self.N) > self.p
        # Resolve each applied op's random arguments ONCE per clip. Previously
        # arg_dict[name](level) ran per frame, so e.g. Rotate re-drew its sign
        # for every frame and frames of one clip could rotate in opposite directions.
        op_args = [
            arg_dict[name](level) if apply else None
            for (name, level), apply in zip(ops, apply_or_not)
        ]

        frames = torch.stack(
            [self._aug(frame, ops, apply_or_not, op_args) for frame in frames], dim=0
        ).float()

        return frames

    def _aug(self, img, ops, apply_or_not, op_args=None):
        """
        Apply the enabled ``ops`` to one frame and return it as a tensor.
        ``op_args`` holds pre-drawn args per op; when None they are drawn here.
        """
        for i, (name, level) in enumerate(ops):
            if not apply_or_not[i]:
                continue
            args = arg_dict[name](level) if op_args is None else op_args[i]
            img = func_dict[name](img, *args)
        return torch.from_numpy(img)
393
394
395
if __name__ == "__main__":
    # Smoke test. The augmentations index 256-entry uint8 lookup tables with the
    # pixel values, so the input must be an integer image in [0, 255]. The former
    # ``np.random.randn`` float image made every table-based op (Solarize,
    # Contrast, Brightness, Posterize, ...) fail.
    a = RandomAugment()
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    out = a(img)
    print(out.shape, out.dtype)