a b/BioSeqNet/resnest/gluon/data_utils.py
1
import random

import numpy as np
from PIL import Image
2
3
import mxnet as mx
4
from mxnet.gluon import Block
5
from ..transforms import *
6
7
class RandAugment(object):
    """RandAugment-style random augmentation for a single image.

    Each call converts the input to a PIL image, draws ``n`` operations from
    ``rand_augment_list()`` (``random.choices``, i.e. with replacement) and
    applies each sampled op with a random probability, at a magnitude that is
    a linear interpolation between the op's ``minval`` and ``maxval``
    controlled by ``m``.

    NOTE(review): unlike ``AugmentationBlock``, the result is returned as
    whatever the last op produced (presumably a PIL image) and is NOT
    converted back to an NDArray — confirm callers expect that.
    """

    def __init__(self, n, m, max_magnitude=30):
        """
        Parameters
        ----------
        n : int
            Number of operations sampled per call.
        m : int or float
            Magnitude on a ``0..max_magnitude`` scale; ``0`` maps every op to
            its ``minval`` and ``max_magnitude`` maps it to its ``maxval``.
        max_magnitude : int or float, optional
            Upper end of the magnitude scale. Defaults to 30, the value that
            was previously hard-coded.

        Raises
        ------
        ValueError
            If ``max_magnitude`` is not positive (it is used as a divisor).
        """
        # Validate first so a bad scale is reported before any other setup.
        if max_magnitude <= 0:
            raise ValueError("max_magnitude must be positive")
        self.n = n
        self.m = m
        self.max_magnitude = max_magnitude
        # Entries are unpacked as (op, minval, maxval) in __call__.
        self.augment_list = rand_augment_list()
        self.topil = ToPIL()

    def __call__(self, img):
        """Apply up to ``n`` randomly chosen ops to ``img`` and return it."""
        img = self.topil(img)
        # The magnitude fraction is the same for every op in this call.
        fraction = float(self.m) / self.max_magnitude
        for op, minval, maxval in random.choices(self.augment_list, k=self.n):
            # Skip the op when a uniform draw exceeds a random threshold in
            # [0.2, 0.8], so the effective apply-probability varies per op.
            if random.random() > random.uniform(0.2, 0.8):
                continue
            val = fraction * float(maxval - minval) + minval
            img = op(img, val)
        return img
23
24
25
class ToPIL(object):
    """Convert an image to a PIL ``Image``.

    Accepts:
      * an object exposing ``asnumpy()`` (e.g. an ``mx.nd`` NDArray), which is
        copied to a host NumPy array first;
      * an existing PIL image, which is returned unchanged so the transform is
        idempotent;
      * anything else ``PIL.Image.fromarray`` accepts (e.g. a NumPy array).
    """

    def __call__(self, img):
        if isinstance(img, Image.Image):
            # Already converted; avoid a pointless round-trip copy.
            return img
        if hasattr(img, "asnumpy"):
            # Device/host NDArray -> NumPy array that PIL can consume.
            img = img.asnumpy()
        return Image.fromarray(img)
31
32
class ToNDArray(object):
    """Convert an image (e.g. a PIL ``Image``) into an MXNet NDArray on ``mx.cpu(0)``.

    The image is first materialised as a NumPy array via ``np.array`` and then
    copied into an NDArray placed on CPU device 0.

    NOTE(review): ``mx.nd.array`` is documented to default to float32 when the
    source is not an NDArray, so a uint8 image likely comes back as float32 —
    confirm downstream transforms expect that.
    """

    def __call__(self, img):
        host_array = np.array(img)
        target_ctx = mx.cpu(0)
        return mx.nd.array(host_array, target_ctx)
36
37
class AugmentationBlock(Block):
    r"""AutoAugment-style augmentation as a Gluon ``Block``.

    On every forward call one sub-policy is drawn uniformly at random from
    ``policies``; its operations are then applied in order, each with its own
    probability. The input is converted to a PIL image before augmentation and
    the result is converted back to an ``mx.nd`` NDArray on ``mx.cpu(0)``.

    Example
    -------
    NOTE(review): the import path below appears to be copied from AutoGluon;
    in this package the class is defined in this module — confirm where
    ``autoaug_imagenet_policies`` is provided before using it.

    >>> from autogluon.utils.augment import AugmentationBlock, autoaug_imagenet_policies
    >>> aa_transform = AugmentationBlock(autoaug_imagenet_policies())
    """

    def __init__(self, policies):
        """
        Parameters
        ----------
        policies : sequence of sub-policies
            Each sub-policy is an iterable of ``(name, pr, level)`` tuples.
            ``name`` and ``level`` are passed to ``apply_augment``; ``pr`` is
            the probability that the operation is applied.

        Raises
        ------
        ValueError
            If ``policies`` is empty. Previously this only surfaced as an
            ``IndexError`` from ``random.choice`` on the first forward call.
        """
        super().__init__()
        if not policies:
            raise ValueError("policies must contain at least one sub-policy")
        self.policies = policies
        self.topil = ToPIL()
        self.tond = ToNDArray()

    def forward(self, img):
        """Augment ``img`` with one randomly chosen sub-policy.

        Returns the augmented image as an NDArray (via ``ToNDArray``).
        """
        img = self.topil(img)
        policy = random.choice(self.policies)
        for name, pr, level in policy:
            # Apply the op only when the uniform draw is <= pr, i.e. with
            # probability pr (for pr in [0, 1]).
            if random.random() > pr:
                continue
            img = apply_augment(img, name, level)
        img = self.tond(img)
        return img