--- /dev/null
+++ b/datacode/augmentations.py
@@ -0,0 +1,289 @@
+import random
+import numpy as np
+from PIL import ImageOps, ImageFilter, Image
+import torchvision.transforms as torch_transforms
+from torchvision.transforms import InterpolationMode
+
+# import phasepack
+
+##========================Natural Images========================================
+
+class GaussianBlur(object):
+    """Apply PIL Gaussian blur with probability p, using a random sigma in [0.1, 2.0)."""
+    def __init__(self, p):
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            sigma = random.random() * 1.9 + 0.1
+            return img.filter(ImageFilter.GaussianBlur(sigma))
+        else:
+            return img
+
+
+class Solarization(object):
+    """Apply PIL solarization with probability p."""
+    def __init__(self, p):
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return ImageOps.solarize(img)
+        else:
+            return img
+
+###-----------------------------------------------------------------------------
+
+## For SimCLR
+class SimCLRTransform:
+    """Transforms applied to each view during the SimCLR pre-training stage.
+
+    Transform::
+
+        RandomResizedCrop(size=self.image_size)
+        RandomHorizontalFlip()
+        RandomApply([color_jitter], p=0.8)
+        RandomGrayscale(p=0.2)
+        RandomApply([GaussianBlur(kernel_size=int(0.1 * self.image_size))], p=0.5)
+        transforms.ToTensor()
+
+    Example::
+
+        transform = SimCLRTransform(image_size=256)
+        x = sample()
+        (xi, xj) = transform(x)
+    """
+
+    def __init__(
+        self, image_size: int = 256, jitter_strength: float = 1.0, normalize=None
+    ) -> None:
+
+        self.jitter_strength = jitter_strength
+        self.image_size = image_size
+        self.normalize = normalize
+
+        # torchvision's GaussianBlur requires an odd kernel size
+        kernel_size = int(0.1 * self.image_size)
+        if kernel_size % 2 == 0:
+            kernel_size += 1
+
+        self.data_transform = torch_transforms.Compose([
+            torch_transforms.RandomResizedCrop(size=self.image_size),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomApply([
+                torch_transforms.ColorJitter(
+                    0.8 * self.jitter_strength,
+                    0.8 * self.jitter_strength,
+                    0.8 * self.jitter_strength,
+                    0.2 * self.jitter_strength)], p=0.8),
+            torch_transforms.RandomGrayscale(p=0.2),
+            torch_transforms.RandomApply(
+                [torch_transforms.GaussianBlur(kernel_size=kernel_size)],
+                p=0.5),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                       std=[0.229, 0.224, 0.225])
+        ])
+
+    def __call__(self, sample):
+        transform = self.data_transform
+
+        xi = transform(sample)
+        xj = transform(sample)
+
+        return xi, xj
+
+    def get_composition(self):
+        return self.data_transform
+
+
+## For BarlowTwins and VICReg
+class BarlowTwinsTransformOrig:
+    def __init__(self, image_size=256):
+        self.transform = torch_transforms.Compose([
+            torch_transforms.RandomResizedCrop(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomApply(
+                [torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,
+                                              saturation=0.2, hue=0.1)],
+                p=0.8
+            ),
+            torch_transforms.RandomGrayscale(p=0.2),
+            GaussianBlur(p=1.0),
+            Solarization(p=0.0),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                       std=[0.229, 0.224, 0.225])
+        ])
+        self.transform_prime = torch_transforms.Compose([
+            torch_transforms.RandomResizedCrop(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomApply(
+                [torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,
+                                              saturation=0.2, hue=0.1)],
+                p=0.8
+            ),
+            torch_transforms.RandomGrayscale(p=0.2),
+            GaussianBlur(p=0.1),
+            Solarization(p=0.2),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                       std=[0.229, 0.224, 0.225])
+        ])
+
+    def __call__(self, x):
+        y1 = self.transform(x)
+        y2 = self.transform_prime(x)
+        return y1, y2
+
+    def get_composition(self):
+        return (self.transform, self.transform_prime)
+
+
+##======================== Ultrasound Images ===================================
+
+
+class ClassifierTransform:
+    def __init__(self, mode="train", image_size=256):
+
+        data_mean = [0.485, 0.456, 0.406]
+        data_std = [0.229, 0.224, 0.225]
+
+        train_transform = torch_transforms.Compose([
+            torch_transforms.Resize(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandAugment(num_ops=5, magnitude=5),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomVerticalFlip(p=0.5),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=data_mean, std=data_std)
+        ])
+        infer_transform = torch_transforms.Compose([
+            torch_transforms.Resize(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=data_mean, std=data_std)
+        ])
+
+        if mode == "train":
+            self.transform = train_transform
+        elif mode == "infer":
+            self.transform = infer_transform
+        else:
+            raise ValueError("Unknown mode: only `train` or `infer` are allowed")
+
+    def __call__(self, x):
+        y = self.transform(x)
+        return y
+
+    def get_composition(self):
+        return self.transform
+
+##------------------------------------------------------------------------------
+
+
+# def phasecongruence(image):
+#     image = np.asarray(image)
+#     [M, ori, ft, T] = phasepack.phasecongmono(image,
+#                                               nscale=5, minWaveLength=5)
+#     out = ((M - M.min()) / (M.max() - M.min() + 1) * 255.0).astype(np.uint8)
+#     out = Image.fromarray(out).convert("RGB")
+#     return out
+
+
+class CustomInfoMaxTransform:
+    def __init__(self, image_size=256):
+
+        self.transform = torch_transforms.Compose([
+            torch_transforms.RandomResizedCrop(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            # torch_transforms.Lambda(phasecongruence),
+            torch_transforms.RandAugment(num_ops=5, magnitude=5),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                       std=[0.229, 0.224, 0.225])
+        ])
+        self.transform_prime = torch_transforms.Compose([
+            torch_transforms.RandomResizedCrop(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomApply(
+                [torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,
+                                              saturation=0.2, hue=0.1)],
+                p=0.8
+            ),
+            torch_transforms.RandomGrayscale(p=0.2),
+            GaussianBlur(p=0.1),
+            Solarization(p=0.2),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                       std=[0.229, 0.224, 0.225])
+        ])
+
+    def __call__(self, x):
+        y1 = self.transform(x)
+        y2 = self.transform_prime(x)
+        return y1, y2
+
+    def get_composition(self):
+        return (self.transform, self.transform_prime)
+
+
+##------------------------------------------------------------------------------
+
+
+class AEncStandardTransform:
+    def __init__(self, image_size=256):
+
+        data_mean = [0.485, 0.456, 0.406]
+        data_std = [0.229, 0.224, 0.225]
+
+        train_transform = torch_transforms.Compose([
+            torch_transforms.Resize(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandAugment(num_ops=3, magnitude=3),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomVerticalFlip(p=0.5),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=data_mean, std=data_std)
+        ])
+
+        self.transform = train_transform
+
+    def __call__(self, x):
+        y = self.transform(x)
+        return y, y
+
+    def get_composition(self):
+        return self.transform
+
+
+class AEncInpaintTransform:
+    def __init__(self, image_size=256):
+
+        data_mean = [0.485, 0.456, 0.406]
+        data_std = [0.229, 0.224, 0.225]
+
+        self.transform = torch_transforms.Compose([
+            torch_transforms.Resize(image_size,
+                interpolation=InterpolationMode.BICUBIC),
+            torch_transforms.RandAugment(num_ops=1, magnitude=1),
+            torch_transforms.RandomHorizontalFlip(p=0.5),
+            torch_transforms.RandomVerticalFlip(p=0.5),
+            torch_transforms.ToTensor(),
+            torch_transforms.Normalize(mean=data_mean, std=data_std)
+        ])
+
+        self.erase = torch_transforms.RandomErasing(p=1, value="random")
+
+    def __call__(self, x):
+        y2 = self.transform(x)
+        y1 = self.erase(y2)
+        return y1, y2
+
+    def get_composition(self):
+        return (self.transform, self.erase)
\ No newline at end of file
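
Usage note (not part of the diff): a minimal sketch of how the two-view and single-view transforms above are expected to be called. The import path follows the diff's target file; the image file name is a placeholder.

# Minimal usage sketch -- assumes the module is importable as `datacode.augmentations`
# and that inputs are RGB PIL images; "sample.png" is a hypothetical path.
from PIL import Image
from datacode.augmentations import BarlowTwinsTransformOrig, ClassifierTransform

img = Image.open("sample.png").convert("RGB")

# Two-view transform for joint-embedding pre-training: returns two differently
# augmented, normalized tensors of the same image.
two_view = BarlowTwinsTransformOrig(image_size=256)
view1, view2 = two_view(img)

# Single-view transform for evaluation/inference: returns one normalized tensor.
infer_tf = ClassifierTransform(mode="infer", image_size=256)
x = infer_tf(img)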