--- a
+++ b/data/transforms.py
@@ -0,0 +1,46 @@
+import torch.nn as nn
+import torchvision.transforms as T
+
+from util.constants import IMAGENET_MEAN, IMAGENET_STD
+
+import cv2
+
+def get_transforms(split, augmentation, image_size):
+
+    IMAGE_SIZE = (224, 224)
+    
+    if split != "train":
+        augmentation = 'none' # Only do augmentations if its the training dataset
+        
+    augmentation_transforms = { # 4 different levels of augmentation transformations
+        'none': [nn.Identity()],
+        'flip': [T.RandomVerticalFlip(), T.RandomHorizontalFlip()],
+        'affine': [T.RandomVerticalFlip(), 
+                   T.RandomHorizontalFlip(),
+                   T.RandomAffine(degrees=90, translate=(0.03, 0.03), scale=(0.95, 1.05))],
+        'aggressive': [T.RandomResizedCrop(image_size, scale=(0.9, 1.1), ratio=(1.0, 1.0)),
+                       T.RandomVerticalFlip(),
+                       T.RandomHorizontalFlip(),
+                       T.RandomAffine(degrees=90, translate=(0.03, 0.03), scale=(0.95, 1.05)),
+                       T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)],
+        'simclr': [T.RandomHorizontalFlip(),
+                    T.RandomResizedCrop(size=96),
+                    T.RandomApply([
+                    T.ColorJitter(brightness=0.5,contrast=0.5,
+                     saturation=0.5,
+                     hue=0.1)], p=0.8),
+                     T.RandomGrayscale(p=0.2),
+                    T.GaussianBlur(kernel_size=9)]
+    }
+        
+    augmentation_transform = augmentation_transforms[augmentation] 
+    resize_transform = [T.Resize((image_size, image_size))] # Resize to square of specified dimensions
+    totensor_transform = [T.ToTensor()] # Convert to tensor
+    #normalize_transform = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) # Normalize based on image net
+
+    transforms_list = augmentation_transform + resize_transform + totensor_transform
+    #if pretrained: # Normalized based on ImageNet if we are using a pre-trained model
+    #    transforms_list.append(normalize_transform) 
+        
+    transforms = T.Compose(transforms_list)
+    return transforms
\ No newline at end of file