b/data/transforms.py
import torch.nn as nn
import torchvision.transforms as T

from util.constants import IMAGENET_MEAN, IMAGENET_STD


def get_transforms(split, augmentation, image_size):
    """Build the torchvision transform pipeline for a dataset split."""

    if split != "train":
        augmentation = 'none'  # Only augment the training split

    # Augmentation levels, from 'none' up to a SimCLR-style recipe
    augmentation_transforms = {
        'none': [nn.Identity()],
        'flip': [T.RandomVerticalFlip(), T.RandomHorizontalFlip()],
        'affine': [T.RandomVerticalFlip(),
                   T.RandomHorizontalFlip(),
                   T.RandomAffine(degrees=90, translate=(0.03, 0.03), scale=(0.95, 1.05))],
        'aggressive': [T.RandomResizedCrop(image_size, scale=(0.9, 1.1), ratio=(1.0, 1.0)),
                       T.RandomVerticalFlip(),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=90, translate=(0.03, 0.03), scale=(0.95, 1.05)),
                       T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)],
        'simclr': [T.RandomHorizontalFlip(),
                   T.RandomResizedCrop(size=96),
                   T.RandomApply([T.ColorJitter(brightness=0.5, contrast=0.5,
                                                saturation=0.5, hue=0.1)], p=0.8),
                   T.RandomGrayscale(p=0.2),
                   T.GaussianBlur(kernel_size=9)],
    }

    augmentation_transform = augmentation_transforms[augmentation]
    resize_transform = [T.Resize((image_size, image_size))]  # Resize to a square of the requested size
    totensor_transform = [T.ToTensor()]  # Convert PIL image to tensor
    # normalize_transform = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)  # Normalize with ImageNet statistics

    transforms_list = augmentation_transform + resize_transform + totensor_transform
    # if pretrained:  # Normalize with ImageNet statistics when using a pre-trained model
    #     transforms_list.append(normalize_transform)

    transforms = T.Compose(transforms_list)
    return transforms
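

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): build the
    # training and validation pipelines and apply them to a dummy RGB image.
    # The split/augmentation names and the dummy image size are assumptions
    # chosen for demonstration only.
    from PIL import Image

    train_tfms = get_transforms(split="train", augmentation="affine", image_size=224)
    val_tfms = get_transforms(split="val", augmentation="affine", image_size=224)  # non-train splits fall back to 'none'

    dummy = Image.new("RGB", (640, 480))
    print(train_tfms(dummy).shape)  # torch.Size([3, 224, 224])
    print(val_tfms(dummy).shape)    # torch.Size([3, 224, 224])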