Diff of /datacode/augmentations.py [000000] .. [a18f15]

Switch to unified view

a b/datacode/augmentations.py
1
import random
2
import numpy as np
3
from PIL import ImageOps, ImageFilter, Image
4
import torchvision.transforms as torch_transforms
5
from torchvision.transforms import InterpolationMode
6
7
# import phasepack
8
9
##========================Natural Images========================================
10
11
class GaussianBlur(object):
    """Randomly Gaussian-blur a PIL image.

    With probability ``p`` the image is blurred using a sigma drawn
    uniformly from [0.1, 2.0); otherwise the image is returned untouched.
    """

    def __init__(self, p):
        self.p = p  # probability of applying the blur

    def __call__(self, img):
        # Guard clause: most of the time (1 - p) we pass the image through.
        if random.random() >= self.p:
            return img
        sigma = 0.1 + random.random() * 1.9
        return img.filter(ImageFilter.GaussianBlur(sigma))
21
22
23
class Solarization(object):
    """Randomly solarize a PIL image (invert pixels above the default threshold)."""

    def __init__(self, p):
        self.p = p  # probability of applying solarization

    def __call__(self, img):
        # One uniform draw decides whether this sample is transformed.
        apply_it = random.random() < self.p
        return ImageOps.solarize(img) if apply_it else img
32
33
###-----------------------------------------------------------------------------
34
35
## For SimCLR
36
class SimCLRTransform:
    """Augmentation pipeline for the SimCLR pre-training stage.

    Composition::

        RandomResizedCrop(size=image_size)
        RandomHorizontalFlip(p=0.5)
        RandomApply([ColorJitter(...)], p=0.8)
        RandomGrayscale(p=0.2)
        RandomApply([GaussianBlur(kernel_size=int(0.1 * image_size))], p=0.5)
        ToTensor()
        Normalize(ImageNet mean/std)

    Example::

        transform = SimCLRTransform(image_size=32)
        xi, xj = transform(sample)  # two independently augmented views
    """

    def __init__(
        self, image_size: int = 256, jitter_strength: float = 1.0, normalize=None
    ) -> None:
        self.jitter_strength = jitter_strength
        self.image_size = image_size
        # NOTE(review): `normalize` is stored but never used — the pipeline
        # below always applies the hard-coded ImageNet statistics.
        self.normalize = normalize

        # torchvision's GaussianBlur requires an odd kernel size.
        kernel_size = int(0.1 * self.image_size)
        if kernel_size % 2 == 0:
            kernel_size += 1

        # ColorJitter strengths follow the SimCLR recipe, scaled by
        # `jitter_strength`.
        color_jitter = torch_transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength,
        )

        self.data_transform = torch_transforms.Compose([
            torch_transforms.RandomResizedCrop(size=self.image_size),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            torch_transforms.RandomApply([color_jitter], p=0.8),
            torch_transforms.RandomGrayscale(p=0.2),
            torch_transforms.RandomApply(
                [torch_transforms.GaussianBlur(kernel_size=kernel_size)],
                p=0.5),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, sample):
        """Return two independently augmented views (xi, xj) of ``sample``."""
        transform = self.data_transform

        xi = transform(sample)
        xj = transform(sample)

        return xi, xj

    def get_composition(self):
        """Return the underlying torchvision ``Compose`` pipeline."""
        return self.data_transform
96
97
98
99
## For BalowTwins and VICRegularization
100
class BarlowTwinsTransformOrig:
    """Two-branch augmentation for Barlow Twins / VICReg pre-training.

    Both branches share crop / flip / color-jitter / grayscale; they differ
    only in the blur and solarization probabilities.
    """

    def __init__(self, image_size=256):
        # Branch 1: always blurred, never solarized.
        self.transform = self._build_branch(image_size, blur_p=1.0, solar_p=0.0)
        # Branch 2: rarely blurred, occasionally solarized.
        self.transform_prime = self._build_branch(image_size, blur_p=0.1, solar_p=0.2)

    @staticmethod
    def _build_branch(image_size, blur_p, solar_p):
        # One augmentation branch; only blur/solarization strengths vary.
        return torch_transforms.Compose([
            torch_transforms.RandomResizedCrop(
                image_size, interpolation=InterpolationMode.BICUBIC),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            torch_transforms.RandomApply(
                [torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.2, hue=0.1)],
                p=0.8
            ),
            torch_transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solar_p),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, x):
        """Return the (view, view') pair for one input image."""
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

    def get_composition(self):
        """Return both branch pipelines as a tuple."""
        return (self.transform, self.transform_prime)
142
143
144
##======================== UltraSound Images ===================================
145
146
147
class ClassifierTransform:
    """Supervised-classifier transform with a `train` and an `infer` mode.

    `train` adds RandAugment plus horizontal/vertical flips on top of the
    shared resize / tensor / ImageNet-normalize steps; `infer` applies only
    the shared steps.
    """

    def __init__(self, mode="train", image_size=256):
        data_mean = [0.485, 0.456, 0.406]
        data_std = [0.229, 0.224, 0.225]

        if mode == "train":
            self.transform = torch_transforms.Compose([
                torch_transforms.Resize(image_size,
                                        interpolation=InterpolationMode.BICUBIC),
                torch_transforms.RandAugment(num_ops=5, magnitude=5),
                torch_transforms.RandomHorizontalFlip(p=0.5),
                torch_transforms.RandomVerticalFlip(p=0.5),
                torch_transforms.ToTensor(),
                torch_transforms.Normalize(mean=data_mean, std=data_std),
            ])
        elif mode == "infer":
            self.transform = torch_transforms.Compose([
                torch_transforms.Resize(image_size,
                                        interpolation=InterpolationMode.BICUBIC),
                torch_transforms.ToTensor(),
                torch_transforms.Normalize(mean=data_mean, std=data_std),
            ])
        else:
            raise ValueError("Unknown Mode set only `train` or `infer` allowed")

    def __call__(self, x):
        """Apply the selected pipeline to one image."""
        return self.transform(x)

    def get_composition(self):
        """Return the selected torchvision ``Compose`` pipeline."""
        return self.transform
180
181
##------------------------------------------------------------------------------
182
183
184
# def phasecongruence(image):
185
#     image = np.asarray(image)
186
#     [M, ori, ft, T] = phasepack.phasecongmono(image,
187
#                             nscale=5, minWaveLength=5)
188
#     out = ((M - M.min())/(M.max() - M.min()+1) *255.0).astype(np.uint8)
189
#     out = Image.fromarray(out).convert("RGB")
190
#     return out
191
192
193
class CustomInfoMaxTransform:
    """Two-branch augmentation for the custom InfoMax objective.

    Branch 1 uses RandAugment; branch 2 uses the BarlowTwins-style
    jitter / grayscale / blur / solarization recipe. Both resize to
    ``image_size`` and apply ImageNet normalization.
    """

    def __init__(self, image_size=256):

        self.transform = torch_transforms.Compose([
            # BUGFIX: the crop size was hard-coded to 256, silently ignoring
            # the `image_size` argument in this branch (the prime branch
            # already honored it).
            torch_transforms.RandomResizedCrop(image_size,
                                    interpolation=InterpolationMode.BICUBIC),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            # torch_transforms.Lambda(phasecongruence),
            torch_transforms.RandAugment(num_ops=5, magnitude=5),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize( mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
        ])
        self.transform_prime = torch_transforms.Compose([
            torch_transforms.RandomResizedCrop(image_size,
                                    interpolation=InterpolationMode.BICUBIC),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            torch_transforms.RandomApply(
                [torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.2, hue=0.1)],
                p=0.8
            ),
            torch_transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.1),
            Solarization(p=0.2),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize( mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        """Return the (view, view') pair for one input image."""
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

    def get_composition(self):
        """Return both branch pipelines as a tuple."""
        return (self.transform, self.transform_prime)
230
231
232
##------------------------------------------------------------------------------
233
234
235
236
class AEncStandardTransform:
    """Single-view augmentation for a standard autoencoder.

    ``__call__`` returns the same augmented tensor twice, i.e. the target
    equals the (augmented) input.
    """

    def __init__(self, image_size=256):
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        # Lighter RandAugment (3 ops, magnitude 3) than the classifier recipe.
        self.transform = torch_transforms.Compose([
            torch_transforms.Resize(image_size,
                                    interpolation=InterpolationMode.BICUBIC),
            torch_transforms.RandAugment(num_ops=3, magnitude=3),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            torch_transforms.RandomVerticalFlip(p=0.5),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def __call__(self, x):
        """Return (augmented, augmented) — identical input/target pair."""
        augmented = self.transform(x)
        return augmented, augmented

    def get_composition(self):
        """Return the torchvision ``Compose`` pipeline."""
        return self.transform
261
262
263
264
class AEncInpaintTransform:
    """Inpainting-autoencoder transform: target is the clean augmented image,
    input is the same tensor with a random patch erased."""

    def __init__(self, image_size=256):
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        # Mild augmentation (1 op, magnitude 1) so the reconstruction target
        # stays close to the original image.
        self.transform = torch_transforms.Compose([
            torch_transforms.Resize(image_size,
                                    interpolation=InterpolationMode.BICUBIC),
            torch_transforms.RandAugment(num_ops=1, magnitude=1),
            torch_transforms.RandomHorizontalFlip(p=0.5),
            torch_transforms.RandomVerticalFlip(p=0.5),
            torch_transforms.ToTensor(),
            torch_transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

        # Always erase one random-valued rectangle from the input view.
        self.erase = torch_transforms.RandomErasing(p=1, value="random")

    def __call__(self, x):
        """Return (erased_view, clean_view) for one input image."""
        clean = self.transform(x)
        erased = self.erase(clean)
        return erased, clean

    def get_composition(self):
        """Return the (base pipeline, erasing transform) pair."""
        return (self.transform, self.erase)