Diff of /libs/datasets/augment.py [000000] .. [98e649]

Switch to unified view

a b/libs/datasets/augment.py
1
2
import numpy as np
3
import random
4
import torch
5
import numpy as np
6
from PIL import Image, ImageEnhance
7
8
class Compose:
    """Chain several transforms into one callable.

    Args:
        transforms: iterable of callables. Each one receives the current
            image and returns the image handed to the next one.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Run ``img`` through every transform in order and return the result."""
        current = img
        for transform in self.transforms:
            current = transform(current)
        return current
16
17
class to_Tensor:
    """Turn an array-like image into a ``torch`` tensor with the last axis first."""

    def __call__(self, arr):
        """Convert ``arr`` to a tensor whose axes are reordered ``(2, 0, 1)``.

        ``arr`` is copied into a NumPy array first. A 2-D input gets a
        trailing length-1 axis so that it can be reordered like a 3-D one.
        No value scaling or dtype conversion is performed.
        """
        data = np.array(arr)
        if data.ndim == 2:
            # Single-plane image: add the axis that becomes axis 0 below.
            data = data[:, :, None]
        return torch.from_numpy(data.transpose(2, 0, 1))
23
24
def imresize(im, size, interp='bilinear'):
    """Resize an image using a resampling filter selected by name.

    Args:
        im: object exposing ``resize(size, resample)`` (a PIL image).
        size: target size, forwarded unchanged to ``im.resize``.
        interp: one of ``'nearest'``, ``'bilinear'`` or ``'bicubic'``.

    Returns:
        The value returned by ``im.resize(size, resample)``.

    Raises:
        ValueError: if ``interp`` is not one of the supported names.
    """
    if interp == 'nearest':
        resample = Image.NEAREST
    elif interp == 'bilinear':
        resample = Image.BILINEAR
    elif interp == 'bicubic':
        resample = Image.BICUBIC
    else:
        # ValueError subclasses Exception, so callers that caught the former
        # bare Exception keep working.
        raise ValueError('resample method undefined!')
    return im.resize(size, resample)
34
35
class To_PIL_Image:
    """Callable wrapper around :func:`to_pil_image` for use in a transform list."""

    def __call__(self, img):
        """Return ``to_pil_image(img)`` (default ``mode``)."""
        converted = to_pil_image(img)
        return converted
38
39
class normalize():
    """Standardize a tensor image as ``(img - mean) / std``.

    Args:
        mean: scalar or sequence of per-channel means.
        std: scalar or sequence of per-channel standard deviations.

    The statistics are stored once as tensors in ``self.mean`` / ``self.std``
    and are never modified afterwards.
    """

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    @staticmethod
    def _aligned(stat, img):
        """Return ``stat`` cast to ``img``'s dtype/device and aligned to its channels."""
        stat = torch.as_tensor(stat, dtype=img.dtype, device=img.device)
        # A 1-D statistic with one entry per channel must broadcast over the
        # channel axis of a (..., C, H, W) tensor, not over its last axis.
        if stat.dim() == 1 and img.dim() >= 3 and stat.numel() == img.shape[-3]:
            stat = stat.view(-1, 1, 1)
        return stat

    def __call__(self, img):
        """Return the normalized copy of ``img``.

        The cast to ``img``'s dtype is done on local copies, so a lossy cast
        (e.g. for an integer image) does not corrupt the stored statistics
        used by later calls.
        """
        mean = self._aligned(self.mean, img)
        std = self._aligned(self.std, img)
        return (img - mean) / std
47
48
class RandomVerticalFlip():
    """Flip an image top-to-bottom with probability ``prob``.

    Args:
        prob: chance in ``[0, 1]`` that a call flips its input. Defaults to
            0.5, matching ``RandomHorizontallyFlip``.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img):
        """Return ``img`` flipped vertically, or ``img`` itself.

        PIL images are transposed with ``FLIP_TOP_BOTTOM``; NumPy arrays are
        reversed along axis 0 (assumes rows are axis 0 -- confirm against
        callers). Any other input type is returned unchanged.
        """
        if not (random.random() < self.prob):
            return img
        if isinstance(img, np.ndarray):
            return np.flip(img, axis=0)
        if isinstance(img, Image.Image):
            return img.transpose(Image.FLIP_TOP_BOTTOM)
        return img
59
60
class RandomHorizontallyFlip:
    """Flip an image left-to-right with probability ``prob`` (default 0.5)."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img):
        """Return ``img`` mirrored horizontally, or ``img`` itself.

        NumPy arrays are reversed along axis 1 and PIL images are transposed
        with ``FLIP_LEFT_RIGHT``. Inputs of any other type pass through
        untouched.
        """
        should_flip = random.random() < self.prob
        if not should_flip:
            return img
        if isinstance(img, np.ndarray):
            return np.flip(img, axis=1)
        if isinstance(img, Image.Image):
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img
71
72
class RandomRotate:
    """Rotate an image by a random angle with probability ``prob``.

    Args:
        degree: the angle is drawn from ``[-degree, degree)``.
        prob: chance that a call rotates its input (default 0.5).
    """

    def __init__(self, degree, prob=0.5):
        self.degree = degree
        self.prob = prob

    def __call__(self, img, interpolation=Image.BILINEAR):
        """Return ``img.rotate(angle, interpolation)`` or ``img`` unchanged."""
        if not (random.random() < self.prob):
            return img
        # Scale a [0, 1) sample onto [-degree, degree).
        angle = random.random() * 2 * self.degree - self.degree
        return img.rotate(angle, interpolation)
82
83
class RandomBrightness:
    """Randomly change image brightness with probability ``prob``."""

    def __init__(self, min_factor, max_factor, prob=0.5):
        """Store the sampling range for the brightness factor.

        :param min_factor: lower bound of the factor, between 0.0 and
            ``max_factor``. 0.0 gives a black image, 1.0 gives the original
            image, and values above 1.0 give a brighter image.
        :param max_factor: upper bound of the factor; should be larger than
            ``min_factor``. Interpreted on the same scale as ``min_factor``.
        :param prob: chance that a call modifies its input.
        """
        self.min_factor = min_factor
        self.max_factor = max_factor
        self.prob = prob

    def __call__(self, img):
        """Return a brightness-adjusted copy of ``img``, or ``img`` itself."""
        if not (random.random() < self.prob):
            return img
        strength = np.random.uniform(self.min_factor, self.max_factor)
        return ImageEnhance.Brightness(img).enhance(strength)
112
113
class RandomContrast:
    """Randomly change image contrast with probability ``prob``."""

    def __init__(self, min_factor, max_factor, prob=0.5):
        """Store the sampling range for the contrast factor.

        :param min_factor: lower bound of the factor, between 0.0 and
            ``max_factor``. 0.0 gives a solid grey image and 1.0 gives the
            original image.
        :param max_factor: upper bound of the factor; should be larger than
            ``min_factor``. Interpreted on the same scale as ``min_factor``.
        :param prob: chance that a call modifies its input.
        """
        self.min_factor = min_factor
        self.max_factor = max_factor
        self.prob = prob

    def __call__(self, img):
        """Return a contrast-adjusted copy of ``img``, or ``img`` itself."""
        if not (random.random() < self.prob):
            return img
        strength = np.random.uniform(self.min_factor, self.max_factor)
        return ImageEnhance.Contrast(img).enhance(strength)
132
133
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
            Tensors are read channel-first and arrays channel-last; a 2-D
            input is treated as a single-channel image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: if ``pic`` is neither a tensor nor an ndarray, or if no
            mode can be derived from its dtype.
        ValueError: if ``mode`` is not valid for the number of channels.
    """
    # ``torch.FloatTensor`` inputs are scaled by 255 and stored as bytes.
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()

    if torch.is_tensor(pic):
        if pic.dim() == 2:
            # Single plane without a channel axis: add one up front.
            pic = pic.unsqueeze(0)
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    else:
        npimg = pic

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.ndim == 2:
        # Same treatment for a plain (H, W) array.
        npimg = npimg[:, :, None]

    channels = npimg.shape[2]
    if channels == 1:
        npimg = npimg[:, :, 0]
        # For one channel the mode is dictated by the dtype alone.
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        else:
            expected_mode = None
        if mode is not None and mode != expected_mode:
            # Report the dtype of the data, not the ``np.dtype`` class.
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif channels == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)