Diff of /data/transforms.py [000000] .. [139527]

Switch to unified view

a b/data/transforms.py
1
import torch.nn as nn
2
import torchvision.transforms as T
3
4
from util.constants import IMAGENET_MEAN, IMAGENET_STD
5
6
import cv2
7
8
def get_transforms(split, augmentation, image_size):
    """Build the torchvision preprocessing pipeline for one dataset split.

    The returned pipeline applies, in order: the selected augmentation
    transforms, a resize to a square of side ``image_size``, and
    ``T.ToTensor()``. No normalization step is included.

    Args:
        split: Dataset split name. Augmentation is only applied when this is
            ``"train"``; for any other value ``augmentation`` is ignored and
            the ``'none'`` level is used.
        augmentation: Augmentation level, one of ``'none'``, ``'flip'``,
            ``'affine'``, ``'aggressive'`` or ``'simclr'``.
        image_size: Side length used for the final
            ``T.Resize((image_size, image_size))`` and for the random crop of
            the ``'aggressive'`` level.

    Returns:
        A ``T.Compose`` wrapping the augmentation, resize and to-tensor
        transforms.

    Raises:
        KeyError: If ``augmentation`` (after the non-train override) is not
            one of the known levels.
    """
    if split != "train":
        # Only augment the training dataset.
        augmentation = 'none'

    def _flips():
        # Random vertical + horizontal flips shared by several levels.
        return [T.RandomVerticalFlip(), T.RandomHorizontalFlip()]

    def _affine():
        # Small random rotation/translation/scale shared by several levels.
        return T.RandomAffine(degrees=90, translate=(0.03, 0.03), scale=(0.95, 1.05))

    # Each level maps to a factory so that only the requested pipeline is
    # actually constructed.
    builders = {
        'none': lambda: [nn.Identity()],
        'flip': _flips,
        'affine': lambda: [*_flips(), _affine()],
        'aggressive': lambda: [
            # NOTE(review): the upper scale bound (1.1) exceeds the full image
            # area -- confirm this is intended.
            T.RandomResizedCrop(image_size, scale=(0.9, 1.1), ratio=(1.0, 1.0)),
            *_flips(),
            _affine(),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        ],
        'simclr': lambda: [
            T.RandomHorizontalFlip(),
            # The crop size is fixed at 96 regardless of image_size; the
            # resize step below rescales the result to image_size.
            T.RandomResizedCrop(size=96),
            T.RandomApply(
                [T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)],
                p=0.8,
            ),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=9),
        ],
    }

    try:
        build_augmentation = builders[augmentation]
    except KeyError:
        # Same exception type as a plain dict lookup, but with the list of
        # valid options in the message.
        raise KeyError(
            f"Unknown augmentation {augmentation!r}; "
            f"expected one of {sorted(builders)}"
        ) from None

    transforms_list = [
        *build_augmentation(),
        T.Resize((image_size, image_size)),  # Resize to a square of the requested size
        T.ToTensor(),  # Convert to tensor
    ]
    return T.Compose(transforms_list)