--- a
+++ b/fetal_net/augment.py
@@ -0,0 +1,471 @@
+import numpy as np
+from nilearn.image import resample_to_img
+import random
+import itertools
+
+from skimage.exposure import exposure
+from skimage.filters import gaussian
+from skimage.util import random_noise
+
+from fetal_net.utils.utils import get_image, interpolate_affine_range, MinMaxScaler
+from imgaug import augmenters as iaa
+
+
+def scale_image(affine, scale_factor):
+    scale_factor = np.diag(list(scale_factor) + [1])
+    new_affine = scale_factor.dot(affine)
+    return new_affine
+
+
+def translate_image(affine, translate_factor):
+    translate_factor = np.asarray(translate_factor)
+    new_affine = np.copy(affine)
+    new_affine[0:3, 3] = new_affine[0:3, 3] + translate_factor
+    return new_affine
+
+
+def rotate_image_axis(affine, rotate_factor, axis):
+    return {
+        0: rotate_image_x,
+        1: rotate_image_y,
+        2: rotate_image_z
+    }[axis](affine, rotate_factor)
+
+
+def rotate_image_x(affine, rotate_factor):
+    sin_gamma = np.sin(rotate_factor)
+    cos_gamma = np.cos(rotate_factor)
+    rotation_affine = np.array([[1, 0, 0, 0],
+                                [0, cos_gamma, -sin_gamma, 0],
+                                [0, sin_gamma, cos_gamma, 0],
+                                [0, 0, 0, 1]])
+    new_affine = rotation_affine.dot(affine)
+    return new_affine
+
+
+def rotate_image_y(affine, rotate_factor):
+    sin_gamma = np.sin(rotate_factor)
+    cos_gamma = np.cos(rotate_factor)
+    rotation_affine = np.array([[cos_gamma, 0, sin_gamma, 0],
+                                [0, 1, 0, 0],
+                                [-sin_gamma, 0, cos_gamma, 0],
+                                [0, 0, 0, 1]])
+    new_affine = rotation_affine.dot(affine)
+    return new_affine
+
+
+def rotate_image_z(affine, rotate_factor):
+    sin_gamma = np.sin(rotate_factor)
+    cos_gamma = np.cos(rotate_factor)
+    rotation_affine = np.array([[cos_gamma, -sin_gamma, 0, 0],
+                                [sin_gamma, cos_gamma, 0, 0],
+                                [0, 0, 1, 0],
+                                [0, 0, 0, 1]])
+    new_affine = rotation_affine.dot(affine)
+    return new_affine
+
+
+def rotate_image(affine, rotate_angles):
+    new_affine = np.copy(affine)
+
+    # apply rotations
+    for i, rotate_angle in enumerate(rotate_angles):
+        if rotate_angle != 0:
+            new_affine = rotate_image_axis(new_affine, rotate_angle, axis=i)
+
+    return new_affine
+
+
+def flip_image(affine, axis):
+    new_affine = np.copy(affine)
+
+    for ax in axis:
+        new_affine = rotate_image_axis(new_affine, np.deg2rad(180), axis=ax)
+
+    return new_affine
+
+
+def shot_noise(data):
+    mm_scaler = MinMaxScaler((0, 1))
+    data = mm_scaler.fit_transform(data)
+    # TODO: remove hardcoded quantization number :(
+    data = np.floor(data * 1023) / 1023  # quantization of the data is needed before poisson noise
+    # TODO: check if clip=True really needed
+    new_data = random_noise(data, mode='poisson', clip=True)
+    return mm_scaler.inverse_transform(new_data)
+
+
+def add_gaussian_noise(data, sigma):
+    mm_scaler = MinMaxScaler((0, 1))
+    data = mm_scaler.fit_transform(data)
+    new_data = random_noise(data, mode='gaussian', clip=True, var=sigma**2)
+    return mm_scaler.inverse_transform(new_data)
+
+
+def add_speckle_noise(data, sigma):
+    mm_scaler = MinMaxScaler((0, 1))
+    data = mm_scaler.fit_transform(data)
+    new_data = random_noise(data, mode='speckle', clip=True, var=sigma**2)
+    return mm_scaler.inverse_transform(new_data)
+
+
+def apply_gaussian_filter(data, sigma):
+    return gaussian(data, sigma=sigma)
+
+
+def apply_coarse_dropout(data, rate, size_percent, per_channel=True):
+    mm_scaler = MinMaxScaler((0, 255))
+    data = mm_scaler.fit_transform(data)
+    new_data = iaa.CoarseDropout(p=rate, size_percent=size_percent, per_channel=per_channel).augment_image(data)
+    return mm_scaler.inverse_transform(new_data)
+
+
+def contrast_augment(data, min_per, max_per):
+    # in_range = (np.percentile(data, q=min_per), np.percentile(data, q=max_per))
+    in_range = (min_per, max_per)
+    return exposure.rescale_intensity(data, in_range=in_range, out_range='image')
+
+
+def apply_piecewise_affine(data, truth, prev_truth, mask, scale):
+    rs = np.random.RandomState()
+    vol_pa_transform = iaa.PiecewiseAffine(scale, nb_cols=2, nb_rows=2, order=1, random_state=rs, deterministic=True)
+    truth_pa_transform = iaa.PiecewiseAffine(scale, nb_cols=2, nb_rows=2, order=0, random_state=rs, deterministic=True)
+    data = vol_pa_transform.augment_image(data)
+    truth = truth_pa_transform.augment_image(truth)
+
+    if prev_truth is not None:
+        prev_truth_pa_transform = iaa.PiecewiseAffine(scale, nb_cols=2, nb_rows=2, order=0, random_state=rs,
+                                                      deterministic=True)
+        prev_truth = prev_truth_pa_transform.augment_image(prev_truth)
+
+    if mask is not None:
+        mask_pa_transform = iaa.PiecewiseAffine(scale, nb_cols=2, nb_rows=2, order=0, random_state=rs,
+                                                deterministic=True)
+        mask = mask_pa_transform.augment_image(mask)
+
+    return data, truth, prev_truth, mask
+
+
+def apply_elastic_transform(data, truth, prev_truth, mask, alpha, sigma):
+    rs = np.random.RandomState()
+    vol_et_transform = iaa.ElasticTransformation(alpha=alpha, sigma=sigma, order=1, random_state=rs, deterministic=True,
+                                                 mode="nearest")
+    truth_et_transform = iaa.ElasticTransformation(alpha=alpha, sigma=sigma, order=0, random_state=rs,
+                                                   deterministic=True, mode="nearest")
+
+    data = vol_et_transform.augment_image(data)
+    truth = truth_et_transform.augment_image(truth)
+
+    if prev_truth is not None:
+        prev_truth_et_transform = iaa.ElasticTransformation(alpha=alpha, sigma=sigma, order=0, random_state=rs,
+                                                            deterministic=True, mode="nearest")
+        prev_truth = prev_truth_et_transform.augment_image(prev_truth)
+
+    if mask is not None:
+        mask_et_transform = iaa.ElasticTransformation(alpha=alpha, sigma=sigma, order=0, random_state=rs,
+                                                      deterministic=True, mode="nearest")
+        mask = mask_et_transform.augment_image(mask)
+
+    return data, truth, prev_truth, mask
+
+
+def random_scale_factor(n_dim=3, mean=1, std=0.25):
+    return np.random.normal(mean, std, n_dim)
+
+
+def random_translate_factor(n_dim=3, min=0, max=7):
+    return np.random.uniform(min, max, n_dim)
+
+
+def random_rotation_angle(n_dim=3, mean=0, std=5):
+    return np.random.uniform(low=mean - np.array(std), high=mean + np.array(std), size=n_dim)
+
+
+def random_boolean():
+    return np.random.choice([True, False])
+
+
+def distort_image(data, affine, flip_axis=None, scale_factor=None, rotate_factor=None, translate_factor=None):
+    # print('Affine1: ', str(affine))
+
+    # translate center of image to 0,0,0
+    center_offset = np.array(data.shape) / 2
+    affine = translate_image(affine, -center_offset)
+    # print('Affine - center offset: ', str(affine))
+
+    if flip_axis is not None:
+        affine = flip_image(affine, flip_axis)
+    if scale_factor is not None:
+        affine = scale_image(affine, scale_factor)
+    if rotate_factor is not None:
+        affine = rotate_image(affine, rotate_factor)
+
+    # translate image back to original coordinates
+    affine = translate_image(affine, +center_offset)
+
+    if translate_factor is not None:
+        affine = translate_image(affine, translate_factor)
+
+    return data, affine
+
+
+def random_flip_dimensions(n_dim, flip_factor):
+    return np.arange(n_dim)[
+        [flip_rate > random.random()
+         for flip_rate in flip_factor]
+    ]
+
+
+def augment_data(data, truth, data_min, data_max, mask=None, scale_deviation=None, iso_scale_deviation=None,
+                 rotate_deviation=None,
+                 translate_deviation=None, flip=None, contrast_deviation=None,
+                 poisson_noise=None, gaussian_noise=None, speckle_noise=None,
+                 piecewise_affine=None, elastic_transform=None, intensity_multiplication_range=None,
+                 gaussian_filter=None, coarse_dropout=None, data_range=None, truth_range=None, prev_truth_range=None):
+    n_dim = len(truth.shape)
+    if scale_deviation:
+        scale_factor = random_scale_factor(n_dim, std=scale_deviation)
+    else:
+        scale_factor = [1, 1, 1]
+    if iso_scale_deviation:
+        iso_scale_factor = np.random.uniform(1, iso_scale_deviation["max"])
+        if random_boolean():
+            iso_scale_factor = 1 / iso_scale_factor
+        scale_factor[0] *= iso_scale_factor
+        scale_factor[1] *= iso_scale_factor
+    else:
+        iso_scale_factor = None
+    if rotate_deviation:
+        rotate_factor = random_rotation_angle(n_dim, std=rotate_deviation)
+        rotate_factor = np.deg2rad(rotate_factor)
+    else:
+        rotate_factor = None
+    if flip is not None and flip:
+        flip_axis = random_flip_dimensions(n_dim, flip)
+    else:
+        flip_axis = None
+    if translate_deviation is not None:
+        translate_factor = random_translate_factor(n_dim, -np.array(translate_deviation), np.array(translate_deviation))
+        translate_factor[-1] = np.floor(translate_factor[-1])  # z-translate should be int
+    else:
+        translate_factor = None
+    if contrast_deviation is not None:
+        val_range = data_max - data_min
+        contrast_min_val = data_min + contrast_deviation["min_factor"] * np.random.uniform(-1, 1) * val_range
+        contrast_max_val = data_max + contrast_deviation["max_factor"] * np.random.uniform(-1, 1) * val_range
+    else:
+        contrast_min_val, contrast_max_val = None, None
+    if poisson_noise is not None:
+        apply_poisson_noise = poisson_noise > np.random.random()
+    else:
+        apply_poisson_noise = False
+    if gaussian_noise is not None:
+        apply_gaussian_noise = gaussian_noise["prob"] > np.random.random()
+    else:
+        apply_gaussian_noise = False
+    if speckle_noise is not None:
+        apply_speckle_noise = speckle_noise["prob"] > np.random.random()
+    else:
+        apply_speckle_noise = False
+
+    if gaussian_filter is not None and gaussian_filter["prob"] > 0:
+        gaussian_sigma = gaussian_filter["max_sigma"] * np.random.random()
+        apply_gaussian = gaussian_filter["prob"] > np.random.random()
+    else:
+        apply_gaussian, gaussian_sigma = False, None
+    if piecewise_affine is not None:
+        piecewise_affine_scale = np.random.random() * piecewise_affine["scale"]
+    else:
+        piecewise_affine_scale = 0
+    if (elastic_transform is not None) and (elastic_transform["alpha"] > 0):
+        elastic_transform_scale = np.random.random() * elastic_transform["alpha"]
+    else:
+        elastic_transform_scale = 0
+    if intensity_multiplication_range is not None:
+        a, b = intensity_multiplication_range
+        intensity_multiplication = np.random.random() * (b - a) + a
+    else:
+        intensity_multiplication = 1
+    if coarse_dropout is not None:
+        coarse_dropout_rate = coarse_dropout['rate']
+        coarse_dropout_size = coarse_dropout['size_percent']
+
+    image, affine = data, np.eye(4)
+    distorted_data, distorted_affine = distort_image(image, affine,
+                                                     flip_axis=flip_axis,
+                                                     scale_factor=scale_factor,
+                                                     rotate_factor=rotate_factor,
+                                                     translate_factor=translate_factor)
+    if data_range is None:
+        data = resample_to_img(get_image(distorted_data, distorted_affine), image, interpolation="continuous",
+                               copy=False, clip=True).get_fdata()
+    else:
+
+        data = interpolate_affine_range(distorted_data, distorted_affine, data_range, order=1, mode='constant',
+                                        cval=data_min)
+
+    truth_image, truth_affine = truth, np.eye(4)
+    distorted_truth_data, distorted_truth_affine = distort_image(truth_image, truth_affine,
+                                                                 flip_axis=flip_axis,
+                                                                 scale_factor=scale_factor,
+                                                                 rotate_factor=rotate_factor,
+                                                                 translate_factor=translate_factor)
+    if truth_range is None:
+        truth_data = resample_to_img(get_image(distorted_truth_data, distorted_truth_affine), truth_image,
+                                     interpolation="nearest", copy=False,
+                                     clip=True).get_data()
+    else:
+        truth_data = interpolate_affine_range(distorted_truth_data, distorted_truth_affine,
+                                              truth_range, order=0, mode='constant', cval=0)
+
+    if prev_truth_range is None:
+        prev_truth_data = None
+    else:
+        prev_truth_data = interpolate_affine_range(distorted_truth_data, distorted_truth_affine,
+                                                   prev_truth_range, order=0, mode='constant', cval=0)
+
+    if mask is None:
+        mask_data = None
+    else:
+        mask_image, mask_affine = mask, np.eye(4)
+        distorted_mask_data, distorted_mask_affine = distort_image(mask_image, mask_affine,
+                                                                   flip_axis=flip_axis,
+                                                                   scale_factor=scale_factor,
+                                                                   rotate_factor=rotate_factor,
+                                                                   translate_factor=translate_factor)
+        if truth_range is None:
+            mask_data = resample_to_img(get_image(distorted_mask_data, distorted_mask_affine), mask_image,
+                                        interpolation="nearest", copy=False,
+                                        clip=True).get_data()
+        else:
+            mask_data = interpolate_affine_range(distorted_mask_data, distorted_mask_affine,
+                                                 truth_range, order=0, mode='constant', cval=0)
+
+    if piecewise_affine_scale > 0:
+        data, truth_data, prev_truth_data, mask_data = apply_piecewise_affine(data, truth_data,
+                                                                              prev_truth_data, mask_data,
+                                                                              piecewise_affine_scale)
+
+    if elastic_transform_scale > 0:
+        data, truth_data, prev_truth_data, mask_data = apply_elastic_transform(data, truth_data,
+                                                                               prev_truth_data, mask_data,
+                                                                               elastic_transform_scale,
+                                                                               elastic_transform["sigma"])
+
+    if contrast_deviation is not None:
+        data = contrast_augment(data, contrast_min_val, contrast_max_val)
+
+    if intensity_multiplication != 1:
+        data = data * intensity_multiplication
+
+    if apply_gaussian:
+        data = apply_gaussian_filter(data, gaussian_sigma)
+
+    if apply_poisson_noise:
+        data = shot_noise(data)
+
+    if apply_speckle_noise:
+        data = add_speckle_noise(data, speckle_noise["sigma"])
+
+    if apply_gaussian_noise:
+        data = add_gaussian_noise(data, gaussian_noise["sigma"])
+
+    if coarse_dropout is not None:
+        data = apply_coarse_dropout(data, rate=coarse_dropout_rate, size_percent=coarse_dropout_size,
+                                    per_channel=coarse_dropout["per_channel"])
+
+    return data, truth_data, prev_truth_data, mask_data
+
+
+def generate_permutation_keys():
+    """
+    This function returns a set of "keys" that represent the 48 unique rotations &
+    reflections of a 3D matrix.
+
+    Each item of the set is a tuple:
+    ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
+
+    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
+    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
+    transposed.
+
+    48 unique rotations & reflections:
+    https://en.wikipedia.org/wiki/Octahedral_symmetry#The_isometries_of_the_cube
+    """
+    return set(itertools.product(
+        itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
+
+
+def random_permutation_key():
+    """
+    Generates and randomly selects a permutation key. See the documentation for the
+    "generate_permutation_keys" function.
+    """
+    return random.choice(list(generate_permutation_keys()))
+
+
+def permute_data(data, key):
+    """
+    Permutes the given data according to the specification of the given key. Input data
+    must be of shape (n_modalities, x, y, z).
+
+    Input key is a tuple: (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
+
+    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
+    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
+    transposed.
+    """
+    data = np.copy(data)
+    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key
+
+    if rotate_y != 0:
+        data = np.rot90(data, rotate_y, axes=(1, 2))
+    # if rotate_z != 0:
+    #     data = np.rot90(data, rotate_z, axes=(2, 3))
+    if flip_x:
+        data = data[:, ::-1]
+    if flip_y:
+        data = data[:, :, ::-1]
+    if flip_z:
+        data = data[:, :, :, ::-1]
+    # if transpose:
+    #    for i in range(data.shape[0]):
+    #        data[i] = data[i].T
+    return data
+
+
+def random_permutation_x_y(x_data, y_data):
+    """
+    Performs random permutation on the data.
+    :param x_data: numpy array containing the data. Data must be of shape (n_modalities, x, y, z).
+    :param y_data: numpy array containing the data. Data must be of shape (n_modalities, x, y, z).
+    :return: the permuted data
+    """
+    key = random_permutation_key()
+    return permute_data(x_data, key), permute_data(y_data, key)
+
+
+def reverse_permute_data(data, key):
+    key = reverse_permutation_key(key)
+    data = np.copy(data)
+    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key
+
+    # if transpose:
+    #    for i in range(data.shape[0]):
+    #        data[i] = data[i].T
+    if flip_z:
+        data = data[:, :, :, ::-1]
+    if flip_y:
+        data = data[:, :, ::-1]
+    if flip_x:
+        data = data[:, ::-1]
+    # if rotate_z != 0:
+    #    data = np.rot90(data, rotate_z, axes=(2, 3))
+    if rotate_y != 0:
+        data = np.rot90(data, rotate_y, axes=(1, 2))
+    return data
+
+
+def reverse_permutation_key(key):
+    rotation = tuple([-rotate for rotate in key[0]])
+    return rotation, key[1], key[2], key[3], key[4]