--- a
+++ b/perturbation/model.py
@@ -0,0 +1,217 @@
+import albumentations as alb
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from perturbation.polyp_inpainter import *
+from perturbation.perturbator import RandomDraw
+
+
+def _to_numpy(sample):
+    """Convert one sample tensor to the NumPy layout the augmentations operate on.
+
+    ``squeeze()`` drops singleton axes (e.g. the channel axis of a 1xHxW mask) and ``.T``
+    reverses the remaining axes, so a CxHxW image becomes WxHxC. Images and masks are
+    transposed the same way, which keeps geometric augmentations spatially aligned.
+    """
+    return sample.squeeze().cpu().numpy().T
+
+
+def _to_tensor(array):
+    """Undo the axis reversal of ``_to_numpy`` and wrap the result as a CPU tensor.
+
+    ``np.ascontiguousarray`` yields a C-contiguous array, so views with negative strides
+    (which flip/rotation ops may return) can still be wrapped by ``torch.from_numpy``. Unlike
+    the legacy ``torch.Tensor(...)`` constructor, the dtype is kept; the caller's indexed
+    assignment into the preallocated output casts to its dtype and copies to its device.
+    """
+    return torch.from_numpy(np.ascontiguousarray(array.T))
+
+
+def _augment_sample(pixelwise, geometric, image, mask):
+    """Apply ``pixelwise`` to the image only, then ``geometric`` jointly to image and mask."""
+    image = pixelwise(image=image)["image"]
+    transformed = geometric(image=image, mask=mask)
+    return transformed["image"], transformed["mask"]
+
+
+class ModelOfNaturalVariationInpainter(nn.Module):
+    """Temperature-controlled albumentations pipeline applied independently to each sample.
+
+    ``temp`` scales augmentation strength; ``step()`` raises it by ``linstep`` and clips it
+    to [0, 1]. Unlike ``ModelOfNaturalVariation``, colour jitter, noise, rotation and flip
+    fire with a fixed probability of 0.5 here; temp only sets the jitter/noise magnitudes
+    and the compression and optical-distortion strength and probability.
+    """
+
+    def __init__(self, T0=0, use_inpainter=False):
+        super().__init__()
+        self.temp = T0
+        self.linstep = 0.1
+        # TODO: inpainting is not integrated yet; this flag is stored but currently unused.
+        self.use_inpainter = use_inpainter
+
+        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
+
+    def _map_to_range(self, max, min=0):
+        """Linearly map the temperature onto [min, max] (temp 0 -> min, temp 1 -> max).
+
+        Callers choose ``max`` well below the physical limit, since e.g. a 100% reduction
+        in brightness is not a productive augmentation.
+        """
+        return min + self.temp * (max - min)
+
+    def get_encoded_transforms(self):
+        """Build the (pixel-wise, geometric) albumentations pipelines for the current temp."""
+        # JPEG quality drops from 100 towards a floor of 10 as temp rises; the sampled
+        # window is [quality_lower, quality_lower + 10], capped at 100.
+        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
+        quality_upper = min(quality_lower + 10, 100)
+        pixelwise = alb.Compose([
+            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
+                            contrast=self._map_to_range(max=0.2),
+                            saturation=self._map_to_range(max=0.2),
+                            hue=self._map_to_range(max=0.05), p=0.5),
+            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=0.5),
+            alb.ImageCompression(quality_lower=quality_lower,
+                                 quality_upper=quality_upper,
+                                 p=self.temp),
+        ])
+        geometric = alb.Compose([
+            alb.RandomRotate90(p=0.5),
+            alb.Flip(p=0.5),
+            alb.OpticalDistortion(distort_limit=self.temp, p=self.temp),
+        ])
+        return pixelwise, geometric
+
+    def forward(self, image, mask):
+        """Augment every sample of a BxCxHxW batch independently.
+
+        Returns new tensors with the same shape, dtype and device as ``image`` and ``mask``.
+        """
+        assert len(image.shape) == 4, "Image must be in BxCxHxW format"
+        augmented_imgs = torch.zeros_like(image)
+        augmented_masks = torch.zeros_like(mask)
+
+        for batch_idx in range(image.shape[0]):  # fresh random transforms for every sample
+            aug_img, aug_mask = _augment_sample(self.pixelwise_augments,
+                                                self.geometric_augments,
+                                                _to_numpy(image[batch_idx]),
+                                                _to_numpy(mask[batch_idx]))
+            augmented_imgs[batch_idx] = _to_tensor(aug_img)
+            augmented_masks[batch_idx] = _to_tensor(aug_mask)
+        return augmented_imgs, augmented_masks
+
+    def set_temp(self, temp):
+        """Set the temperature directly and rebuild the augmentation pipelines."""
+        self.temp = temp
+        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
+
+    def step(self):
+        """Raise the temperature by ``linstep`` (clipped to [0, 1]) and rebuild the pipelines."""
+        self.set_temp(np.clip(self.temp + self.linstep, 0, 1))
+
+
+class ModelOfNaturalVariation(nn.Module):
+    """Temperature-controlled augmentation model that can also insert inpainted polyps.
+
+    Every augmentation's probability and magnitude scale with ``temp``, so at ``temp=0`` none
+    of the albumentations transforms fire; ``step()`` raises ``temp`` by ``linstep`` and
+    clips it to [0, 1].
+    """
+
+    def __init__(self, T0=0, use_inpainter=False, inpaint_prob=1.0):
+        super().__init__()
+        self.temp = T0
+        self.linstep = 0.1
+        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
+        self.use_inpainter = use_inpainter
+        # Per-sample inpainting probability. The default of 1.0 keeps the previous behaviour:
+        # the old check ``np.random.rand(1) < 1`` was always true.
+        self.inpaint_prob = inpaint_prob
+        if use_inpainter:
+            self.inpainter = Inpainter("Predictors/Inpainters/no-pretrain-deeplab-generator-4990")
+
+    def _map_to_range(self, max, min=0):
+        """Linearly map the temperature onto [min, max] (temp 0 -> min, temp 1 -> max).
+
+        Callers choose ``max`` well below the physical limit, since e.g. a 100% reduction
+        in brightness is not a productive augmentation.
+        """
+        return min + self.temp * (max - min)
+
+    def get_encoded_transforms(self):
+        """Build the (pixel-wise, geometric) albumentations pipelines for the current temp."""
+        # JPEG quality drops from 100 towards a floor of 10 as temp rises; the sampled
+        # window is [quality_lower, quality_lower + 10], capped at 100.
+        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
+        quality_upper = min(quality_lower + 10, 100)
+        pixelwise = alb.Compose([
+            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
+                            contrast=self._map_to_range(max=0.2),
+                            saturation=self._map_to_range(max=0.2),
+                            hue=self._map_to_range(max=0.05), p=self.temp),
+            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=self.temp),
+            alb.ImageCompression(quality_lower=quality_lower,
+                                 quality_upper=quality_upper,
+                                 p=self.temp),
+        ])
+        geometric = alb.Compose([
+            alb.RandomRotate90(p=self.temp),
+            alb.Flip(p=self.temp),
+            alb.OpticalDistortion(distort_limit=self.temp, p=self.temp),
+        ])
+        return pixelwise, geometric
+
+    def forward(self, image, mask):
+        """Augment (and optionally inpaint) every sample of the batch independently.
+
+        Returns new tensors with the same shape, dtype and device as ``image`` and ``mask``.
+        NOTE(review): presumably expects BxCxHxW like ``ModelOfNaturalVariationInpainter``,
+        but no shape assert is enforced here — confirm.
+        """
+        augmented_imgs = torch.zeros_like(image)
+        augmented_masks = torch.zeros_like(mask)
+        for batch_idx in range(image.shape[0]):  # fresh random transforms for every sample
+            aug_img = _to_numpy(image[batch_idx])
+            aug_mask = _to_numpy(mask[batch_idx])
+            if self.use_inpainter and np.random.rand() < self.inpaint_prob:
+                aug_img, aug_mask = self.inpainter.add_polyp(aug_img, aug_mask)
+            aug_img, aug_mask = _augment_sample(self.pixelwise_augments,
+                                                self.geometric_augments,
+                                                aug_img, aug_mask)
+            augmented_imgs[batch_idx] = _to_tensor(aug_img)
+            augmented_masks[batch_idx] = _to_tensor(aug_mask)
+        return augmented_imgs, augmented_masks
+
+    def set_temp(self, temp):
+        """Set the temperature directly and rebuild the augmentation pipelines."""
+        self.temp = temp
+        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
+
+    def step(self):
+        """Raise the temperature by ``linstep`` (clipped to [0, 1]) and rebuild the pipelines."""
+        self.set_temp(np.clip(self.temp + self.linstep, 0, 1))
+
+
+if __name__ == '__main__':
+    from data.hyperkvasir import KvasirSegmentationDataset
+    from torch.utils.data import DataLoader
+
+    mnv = ModelOfNaturalVariation(1)
+    for x, y, fname in DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", augment=False)):
+        # DataLoader collates per-sample names into a list (batch_size=1 here), so use the
+        # first entry; formatting the list itself would put "['...']" into the file name.
+        # NOTE(review): assumes the dataset yields the file name as a string — confirm.
+        name = fname[0]
+        aug_img, aug_mask = mnv(x, y)
+        # .numpy().T mirrors the model's axis reversal (CxHxW -> WxHxC); calling .T on a
+        # 3-D torch tensor is deprecated.
+        plt.imshow(x[0].numpy().T)
+        plt.axis("off")
+        plt.savefig(f"experiments/Data/augmentation_samples/unaugmented_{name}.png", bbox_inches='tight')
+        plt.close()  # release the figure so images do not pile up across iterations
+        plt.imshow(aug_img[0].numpy().T)
+        plt.axis("off")
+        plt.savefig(f"experiments/Data/augmentation_samples/augmented_{name}.png", bbox_inches='tight')
+        plt.close()