In [None]:
import numpy as np

import sys
sys.path.append('..')

import nibabel as nib
from matplotlib import pyplot as plt

from fetal_net.augment import augment_data

In [None]:
def slice_it(arr, inds):
    """Crop ``arr`` to the box described by per-axis ``(start, stop)`` ranges.

    Parameters
    ----------
    arr : array-like supporting NumPy basic slicing
        The array to crop. Axes beyond ``len(inds)`` are kept whole.
    inds : sequence of ``(start, stop)`` pairs
        One half-open range per leading axis, e.g.
        ``[(70, 198), (70, 198), (30, 35)]``. A pair may be a tuple, a list or a
        2-element array (such as the result of ``np.add(pair, offset)``); only
        its first two entries are used.

    Returns
    -------
    ``arr[start0:stop0, start1:stop1, ...]`` (a view when ``arr`` is a NumPy array).
    """
    # Build the index from however many ranges were given instead of hard-coding
    # three axes, so 2-D images and higher-rank inputs work as well.
    return arr[tuple(slice(rng[0], rng[1]) for rng in inds)]

In [None]:
# Load one case of the windowed fetal dataset: the image volume and its
# ground-truth mask, both stored as NIfTI files.
volume_path = '../../Datasets/fetus_window_1_99/255/volume.nii'
truth_path = '../../Datasets/fetus_window_1_99/255/truth.nii'
vol = nib.load(volume_path)
mask = nib.load(truth_path)
# Show the volume shape so the crop window below can be checked against it.
vol.shape

In [None]:
# Crop window shared by every augmentation demo below: a start index and an
# extent per axis, turned into half-open (start, stop) pairs.
patch_corner = [70, 70, 30]
patch_shape = [128, 128, 5]
data_range = [(first, first + length) for first, length in zip(patch_corner, patch_shape)]
data_range

In [None]:
# Ground-truth window: identical to the data window on the first two axes, but
# only `truth_size` slice(s) wide on the last axis, starting `truth_index`
# slices into the data window.
truth_index = 2
truth_size = 1
truth_start = patch_corner[2] + truth_index
truth_range = data_range[:2] + [(truth_start, truth_start + truth_size)]
truth_range

# Gaussian Filter

In [None]:
# Gaussian filter only: the original crop (left) next to the augmented patch
# (right), both at slice index 2 of the last axis. `poisson_noise` is
# deliberately not passed, so the right-hand panel shows the filter's effect in
# isolation. Poisson (shot) noise is demonstrated in its own section.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    gaussian_filter={
        'max_sigma': 1.5,
        'prob': 1,  # presumably "always apply" — confirm against augment_data
    },
)
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[slice_it(data, data_range)[..., 2], data2[..., 2]], cmap='gray')

# Shot Noise

In [None]:
# Shot (Poisson) noise: original crop on the left, augmented patch on the right,
# both at slice index 2 of the last axis.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    poisson_noise=0.5,
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# Contrast Deviation

In [None]:
# Contrast deviation: print the raw intensity range, then show the original crop
# (left) next to the augmented patch (right) at slice index 2 of the last axis.
data = vol.get_fdata()
truth = mask.get_fdata()
data_min, data_max = data.min(), data.max()
print(data_min, data_max)
data2, truth2, _ = augment_data(
    data, truth, data_min, data_max,
    data_range=data_range,
    truth_range=truth_range,
    contrast_deviation={'min_factor': 0.2, 'max_factor': 0.8},
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# intensity_multiplication_range

In [None]:
# Intensity multiplication: original crop (left) vs. augmented patch (right),
# both at slice index 2 of the last axis.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    intensity_multiplication_range=[0.8, 1.2],
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# piecewise_affine

In [None]:
# Piecewise-affine warp: original crop (left) vs. augmented patch (right),
# both at slice index 2 of the last axis.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    piecewise_affine={'scale': 0.5},
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# elastic_transform

In [None]:
# Elastic transform: original crop (left) vs. augmented patch (right),
# both at slice index 2 of the last axis.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    elastic_transform={'alpha': 5, 'sigma': 1},
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# scale_deviation

In [None]:
# Scale deviation: original crop (left) vs. augmented patch (right), both at
# slice index 2 of the last axis. The last entry is 0.0, presumably leaving that
# axis unscaled — confirm against augment_data.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    scale_deviation=[0.1, 0.1, 0.0],
)
original_slice = slice_it(data, data_range)[..., 2]
augmented_slice = data2[..., 2]
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[original_slice, augmented_slice], cmap='gray')

# Rotate

In [None]:
# Rotation: image and ground truth before (left) and after (right) augmentation.
# The image panel uses slice index 2 of the last axis. The mask panel uses index
# 0, because the truth window is only `truth_size` (= 1) slice wide.
# NOTE(review): 1800 on the last axis (others are 0) looks like a deliberately
# exaggerated value — confirm the units rotate_deviation expects.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    rotate_deviation=[0, 0, 1800],
)
image_pair = np.c_[slice_it(data, data_range)[..., 2], data2[..., 2]]
truth_pair = np.c_[slice_it(truth, truth_range)[..., 0], truth2[..., 0]]

plt.figure(figsize=(16, 14))
plt.imshow(image_pair, cmap='gray')

plt.figure(figsize=(16, 14))
plt.imshow(truth_pair, cmap='gray')

# Flip

In [None]:
# Flip: image and ground truth before (left) and after (right) augmentation.
# The image panel uses slice index 2 of the last axis. The mask panel uses index
# 0, because the truth window is only `truth_size` (= 1) slice wide.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    flip=[0, 0, 1],
)
image_pair = np.c_[slice_it(data, data_range)[..., 2], data2[..., 2]]
truth_pair = np.c_[slice_it(truth, truth_range)[..., 0], truth2[..., 0]]

plt.figure(figsize=(16, 14))
plt.imshow(image_pair, cmap='gray')

plt.figure(figsize=(16, 14))
plt.imshow(truth_pair, cmap='gray')

# Translate

In [None]:
# Translation: image and ground truth before (left) and after (right)
# augmentation. The image panel uses slice index 2 of the last axis. The mask
# panel uses index 0, because the truth window is only `truth_size` (= 1) slice
# wide. The next cell reuses `data2`/`truth2` from here.
data = vol.get_fdata()
truth = mask.get_fdata()
data2, truth2, _ = augment_data(
    data, truth, data.min(), data.max(),
    data_range=data_range,
    truth_range=truth_range,
    translate_deviation=[0, 0, 10],
)
image_pair = np.c_[slice_it(data, data_range)[..., 2], data2[..., 2]]
truth_pair = np.c_[slice_it(truth, truth_range)[..., 0], truth2[..., 0]]

plt.figure(figsize=(16, 14))
plt.imshow(image_pair, cmap='gray')

plt.figure(figsize=(16, 14))
plt.imshow(truth_pair, cmap='gray')

In [None]:
# Check the translation by hand: shift the crop windows along the last axis by
# `z_trans` and compare the shifted original with the augmented output
# (`data2` / `truth2`) from the Translate cell above.
# NOTE(review): z_trans = 9 is hand-picked; the translation is presumably
# sampled per call, so the match depends on the previous cell's run.
z_trans = 9

# Copy the range lists before editing them. A plain `truth_range2 = truth_range`
# aliases the shared list, so every run of this cell would shift `truth_range`
# itself by another `z_trans` and silently break all later/re-run cells that
# use it.
data_range2 = list(data_range)
data_range2[-1] = tuple(np.add(data_range2[-1], z_trans))

plt.figure(figsize=(16, 14))
plt.imshow(np.c_[slice_it(data, data_range2)[..., 2], data2[..., 2]], cmap='gray')

truth_range2 = list(truth_range)
truth_range2[-1] = tuple(np.add(truth_range2[-1], z_trans))
plt.figure(figsize=(16, 14))
plt.imshow(np.c_[slice_it(truth, truth_range2)[..., 0], truth2[..., 0]], cmap='gray')