--- a
+++ b/notebooks/albumenations_notebook.py
@@ -0,0 +1,533 @@
+# Notebook to check different augmentations
+
+# %%
+import matplotlib.pyplot as plt
+import os
+import yaml
+
+# https://github.com/albumentations-team/albumentations_examples/blob/master/notebooks/example_kaggle_salt.ipynb # noqa
+from albumentations import (
+    PadIfNeeded,
+    HorizontalFlip,
+    VerticalFlip,
+    CenterCrop,
+    Compose,
+    Transpose,
+    RandomRotate90,
+    ElasticTransform,
+    GridDistortion,
+    OpticalDistortion,
+    RandomSizedCrop,
+    RandomResizedCrop,
+    OneOf,
+    CLAHE,
+    RandomBrightnessContrast,
+    RandomGamma,
+    ShiftScaleRotate,
+    IAASharpen,
+    Blur,
+    MotionBlur,
+    ImageCompression,
+    IAAPerspective,
+    MultiplicativeNoise,
+)
+
+# enable lib loading even if not installed as a pip package or in PYTHONPATH
+# also convenient for relative paths in example config files
+from pathlib import Path
+
+os.chdir(Path(__file__).resolve().parent.parent)
+
+from adpkd_segmentation.config.config_utils import get_object_instance  # noqa
+from adpkd_segmentation.data.link_data import makelinks  # noqa
+from adpkd_segmentation.data.data_utils import ( # noqa
+    int16_to_uint8,
+    masks_to_colorimg,
+)
+
+# %%
+# needed only once
+# makelinks()
+
+# %%
+path = "./experiments/september06/random_split_new_data_less_albu/val/val.yaml"
+# relative path: resolved against the working directory set by os.chdir
+with open(path, "r") as f:
+    config = yaml.load(f, Loader=yaml.FullLoader)
+dataloader_config = config["_VAL_DATALOADER_CONFIG"]
+dataloader = get_object_instance(dataloader_config)()
+
+# %%
+# SET THIS INDEX to select the image/label pair used in the examples below
+IMG_IDX = 180
+dataset = dataloader.dataset
+x, y, index = dataset[IMG_IDX]
+
+# %%
+print("Dataset Length: {}".format(len(dataset)))
+print("image -> shape {},  dtype {}".format(x.shape, x.dtype))
+print("mask -> shape {},  dtype {}".format(y.shape, y.dtype))
+
+# %%
+print("Image and Mask: \n")
+image, mask = x[0, ...], y
+# left: image alone; right: same image with the colorized mask at alpha 0.5
+f, (ax1, ax2) = plt.subplots(1, 2)
+ax1.imshow(image, cmap="gray")
+ax2.imshow(image, cmap="gray")
+ax2.imshow(masks_to_colorimg(mask), alpha=0.5)
+
+
+# %%
+# from the albumentations examples
+def visualize(image, mask, original_image=None, original_mask=None):
+    """Plot an image above its mask, beside the originals when given.
+
+    With both originals omitted a plain 2x1 grid is drawn without titles;
+    otherwise a titled 2x2 grid: originals on the left, results on the right.
+    """
+    title_size = 18
+    if original_image is None and original_mask is None:
+        _, axes = plt.subplots(2, 1, figsize=(8, 8))
+        for axis, picture in zip(axes, (image, mask)):
+            axis.imshow(picture)
+        return
+    _, axes = plt.subplots(2, 2, figsize=(8, 8))
+    panels = (
+        (axes[0, 0], original_image, "Original image"),
+        (axes[1, 0], original_mask, "Original mask"),
+        (axes[0, 1], image, "Transformed image"),
+        (axes[1, 1], mask, "Transformed mask"),
+    )
+    for axis, picture, title in panels:
+        axis.imshow(picture)
+        axis.set_title(title, fontsize=title_size)
+
+
+# %%
+# ORIGINAL (slice y, not mask, so re-running the cell cannot re-slice)
+mask = y[0]
+visualize(image, mask)
+
+# %%
+# PADDING (pads up to at least 256x256; p=1 so it is always applied)
+aug = PadIfNeeded(p=1, min_height=256, min_width=256)
+
+augmented = aug(image=image, mask=mask)
+
+image_padded = augmented["image"]
+mask_padded = augmented["mask"]
+# print both shapes so the image and mask sizes can be compared
+print(image_padded.shape, mask_padded.shape)
+
+visualize(image_padded, mask_padded, original_image=image, original_mask=mask)
+
+# %%
+# CENTER CROP (back to 224x224; assumes that is the sample size - TODO confirm)
+original_height, original_width = 224, 224
+
+aug = CenterCrop(p=1, height=original_height, width=original_width)
+
+augmented = aug(image=image_padded, mask=mask_padded)
+
+image_center_cropped = augmented["image"]
+mask_center_cropped = augmented["mask"]
+
+print(image_center_cropped.shape, mask_center_cropped.shape)
+# compare element-wise: a summed signed difference can cancel out to zero
+assert (image == image_center_cropped).all()
+assert (mask == mask_center_cropped).all()
+
+visualize(
+    image_center_cropped,
+    mask_center_cropped,
+    original_image=image_padded,
+    original_mask=mask_padded,
+)
+
+
+# %%
+# Horizontal Flip (p=1: the transform is applied on every call)
+aug = HorizontalFlip(p=1)
+# image and mask are passed in the same call so they get the same flip
+augmented = aug(image=image, mask=mask)
+
+image_h_flipped = augmented["image"]
+mask_h_flipped = augmented["mask"]
+
+visualize(
+    image_h_flipped, mask_h_flipped, original_image=image, original_mask=mask
+)
+
+# %%
+
+# Vertical Flip
+aug = VerticalFlip(p=1)
+
+augmented = aug(image=image, mask=mask)
+
+image_v_flipped = augmented["image"]
+mask_v_flipped = augmented["mask"]
+
+visualize(
+    image_v_flipped, mask_v_flipped, original_image=image, original_mask=mask
+)
+
+# %%
+
+# RandomRotate90 (randomly rotates by 0, 90, 180 or 270 degrees)
+aug = RandomRotate90(p=1)
+# a 0-degree draw leaves the output equal to the input; re-run for another
+augmented = aug(image=image, mask=mask)
+
+image_rot90 = augmented["image"]
+mask_rot90 = augmented["mask"]
+
+visualize(image_rot90, mask_rot90, original_image=image, original_mask=mask)
+# %%
+
+# Transpose (switch X and Y axis)
+aug = Transpose(p=1)
+
+augmented = aug(image=image, mask=mask)
+
+image_transposed = augmented["image"]
+mask_transposed = augmented["mask"]
+
+visualize(
+    image_transposed, mask_transposed, original_image=image, original_mask=mask
+)
+
+# %%
+
+# ElasticTransform (explicit parameters, p=1)
+aug = ElasticTransform(
+    p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03
+)
+
+augmented = aug(image=image, mask=mask)
+
+image_elastic = augmented["image"]
+mask_elastic = augmented["mask"]
+
+visualize(
+    image_elastic, mask_elastic, original_image=image, original_mask=mask
+)
+
+# %%
+
+# ElasticTransform default (no arguments, so library defaults incl. p apply)
+aug = ElasticTransform()
+
+augmented = aug(image=image, mask=mask)
+
+image_elastic = augmented["image"]
+mask_elastic = augmented["mask"]
+
+visualize(
+    image_elastic, mask_elastic, original_image=image, original_mask=mask
+)
+
+# %%
+
+# GridDistortion
+aug = GridDistortion(distort_limit=0.3, p=1)
+
+augmented = aug(image=image, mask=mask)
+
+image_grid = augmented["image"]
+mask_grid = augmented["mask"]
+
+visualize(image_grid, mask_grid, original_image=image, original_mask=mask)
+
+# %%
+# Optical Distortion
+aug = OpticalDistortion(p=1, distort_limit=1, shift_limit=0.3)
+
+augmented = aug(image=image, mask=mask)
+
+image_optical = augmented["image"]
+mask_optical = augmented["mask"]
+
+visualize(
+    image_optical, mask_optical, original_image=image, original_mask=mask
+)
+
+# %%
+# Optical Distortion 3 (same parameters as the previous cell: another draw)
+aug = OpticalDistortion(p=1, distort_limit=1, shift_limit=0.3)
+
+augmented = aug(image=image, mask=mask)
+
+image_optical = augmented["image"]
+mask_optical = augmented["mask"]
+
+visualize(
+    image_optical, mask_optical, original_image=image, original_mask=mask
+)
+
+# %%
+# Optical Distortion default (default limits; p=1 still forces the transform)
+aug = OpticalDistortion(p=1)
+
+augmented = aug(image=image, mask=mask)
+
+image_optical = augmented["image"]
+mask_optical = augmented["mask"]
+
+visualize(
+    image_optical, mask_optical, original_image=image, original_mask=mask
+)
+
+# %%
+# Shift scale rotate (no p passed, so the library default probability applies)
+aug = ShiftScaleRotate(
+    border_mode=0, rotate_limit=20, scale_limit=0.3, shift_limit=0.1
+)
+
+augmented = aug(image=image, mask=mask)
+# NOTE(review): results reuse the *_optical names from the cells above
+image_optical = augmented["image"]
+mask_optical = augmented["mask"]
+
+visualize(
+    image_optical, mask_optical, original_image=image, original_mask=mask
+)
+
+# %%
+
+# RandomSizedCrop
+# crop height is drawn from min_max_height; the crop is resized to 128x128
+aug = RandomSizedCrop(p=1, min_max_height=(100, 200), height=128, width=128)
+
+augmented = aug(image=image, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image, original_mask=mask)
+
+# %%
+
+# RandomResizedCrop
+# scale bounds the cropped fraction of the input; the output is 72x72
+aug = RandomResizedCrop(p=1, height=72, width=72, scale=(0.25, 1.0))
+
+augmented = aug(image=image, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image, original_mask=mask)
+
+
+# %%
+# CLAHE (uint8 only; assumes image is float in [0, 1] - TODO confirm)
+aug = CLAHE()
+image8 = (image * 255).astype("uint8")  # 1.0 -> 255; x256 overflows uint8
+augmented = aug(image=image8, mask=mask.astype("uint8"))
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# RandomBrightnessContrast
+aug = RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2)
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# RandomGamma
+
+aug = RandomGamma(gamma_limit=(40, 200))
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# IAASharpen
+aug = IAASharpen(alpha=(0.1, 0.2), lightness=(0.5, 0.7))
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# Blur
+aug = Blur(blur_limit=2)
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# Motion Blur
+aug = MotionBlur(blur_limit=5)
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# Image Compression
+aug = ImageCompression(quality_lower=50, quality_upper=50)
+image8 = (image * 255).astype("uint8")
+augmented = aug(image=image8, mask=mask)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(image_scaled, mask_scaled, original_image=image8, original_mask=mask)
+
+
+# %%
+# IAAPerspective
+aug = IAAPerspective()
+image8 = (image * 255).astype("uint8")
+mask8 = mask.astype("uint8")
+augmented = aug(image=image8, mask=mask8)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(
+    image_scaled, mask_scaled, original_image=image8, original_mask=mask8
+)
+
+
+# %%
+# MultiplicativeNoise
+aug = MultiplicativeNoise(multiplier=(0.8, 1.2))
+image8 = (image * 255).astype("uint8")
+mask8 = mask.astype("uint8")
+augmented = aug(image=image8, mask=mask8)
+
+image_scaled = augmented["image"]
+mask_scaled = augmented["mask"]
+
+visualize(
+    image_scaled, mask_scaled, original_image=image8, original_mask=mask8
+)
+
+
+# %%
+# combine different transformations (light: flip and 90-degree rotation)
+aug = Compose([VerticalFlip(p=0.5), RandomRotate90(p=0.5)])
+# each step fires with p=0.5, so some runs return the sample unchanged
+augmented = aug(image=image, mask=mask)
+
+image_light = augmented["image"]
+mask_light = augmented["mask"]
+
+visualize(image_light, mask_light, original_image=image, original_mask=mask)
+# %%
+
+# Medium augmentations: crop-or-pad, flip, rotate90, then one distortion
+aug = Compose(
+    [
+        OneOf(
+            [
+                RandomSizedCrop(
+                    min_max_height=(50, 101),
+                    height=original_height,
+                    width=original_width,
+                    p=0.5,
+                ),
+                PadIfNeeded(
+                    min_height=original_height, min_width=original_width, p=0.5
+                ),
+            ],
+            p=1,
+        ),
+        VerticalFlip(p=0.5),
+        RandomRotate90(p=0.5),
+        OneOf(
+            [
+                ElasticTransform(
+                    p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03
+                ),
+                GridDistortion(p=0.5),
+                OpticalDistortion(p=1, distort_limit=1, shift_limit=0.5),
+            ],
+            p=0.8,
+        ),
+    ]
+)
+# original_height / original_width come from the CENTER CROP cell (224)
+augmented = aug(image=image, mask=mask)
+
+image_medium = augmented["image"]
+mask_medium = augmented["mask"]
+
+visualize(image_medium, mask_medium, original_image=image, original_mask=mask)
+# %%
+
+# Heavy augmentations: spatial plus non-spatial transformations
+aug = Compose(
+    [
+        OneOf(
+            [
+                RandomSizedCrop(
+                    min_max_height=(50, 90),
+                    height=original_height,
+                    width=original_width,
+                    p=0.5,
+                ),
+                PadIfNeeded(
+                    min_height=original_height, min_width=original_width, p=0.5
+                ),
+            ],
+            p=1,
+        ),
+        VerticalFlip(p=0.5),
+        RandomRotate90(p=0.5),
+        OneOf(
+            [
+                ElasticTransform(
+                    p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03
+                ),
+                GridDistortion(p=0.5),
+                OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
+            ],
+            p=0.8,
+        ),
+        # CLAHE(p=0.8),  # left out: CLAHE ONLY SUPPORTS UINT8 input
+        RandomBrightnessContrast(p=0.8),
+        RandomGamma(p=0.8),
+    ]
+)
+# the image is passed as-is here (no uint8 conversion as in the cells above)
+augmented = aug(image=image, mask=mask)
+
+image_heavy = augmented["image"]
+mask_heavy = augmented["mask"]
+
+visualize(image_heavy, mask_heavy, original_image=image, original_mask=mask)
+
+# %%