import os
import random
import numpy as np
from keras.utils import to_categorical, Sequence
from fetal.utils import AttributeDict as att_dict
from fetal_net.utils.utils import resize
from .augment import augment_data, random_permutation_x_y, get_image
from .utils import pickle_dump, pickle_load
from .utils.patches import get_patch_from_3d_data
class DataFileDummy:
    """In-memory stand-in for a pytables data file.

    Copies the data, truth and (optional) mask arrays into lists of numpy
    arrays, pads every volume by `pad` voxels, and precomputes per-volume
    intensity statistics. Sets `self.root = self` so it can be used wherever
    the original pytables file object (with its `.root` attribute) is expected.
    """

    def __init__(self, file, pad=0):
        # Pad scans with their own minimum intensity; pad truths with background (0)
        self.data = [np.pad(_, pad, 'constant', constant_values=_.min()) for _ in file.root.data]
        self.truth = [np.pad(_, pad, 'constant', constant_values=0) for _ in file.root.truth]
        if len(file.root.mask):
            self.mask = [_ for _ in file.root.mask]
        else:
            self.mask = None
        # Per-volume intensity statistics (used later by the augmentation)
        self.stats = att_dict(
            p1=[np.percentile(_, q=1) for _ in self.data],
            min=[np.min(_) for _ in self.data],
            max=[np.max(_) for _ in self.data],
        )
        self.subject_ids = [_ for _ in file.root.subject_ids]
        self.root = self
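
def _example_load_data_file(hdf5_path='fetal_data.h5'):
    # Usage sketch (hypothetical path): assumes the HDF5 file was written with
    # pytables and holds /data, /truth, /mask and /subject_ids nodes.
    # DataFileDummy copies everything into memory, so the file can be closed
    # immediately afterwards.
    import tables
    hdf5 = tables.open_file(hdf5_path, 'r')
    try:
        return DataFileDummy(hdf5, pad=3)
    finally:
        hdf5.close()
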
def pad_samples(data_file, patch_shape, truth_downsample):
    # The network predicts a (possibly downsampled) single-slice output per
    # input patch, so pad each volume by half the difference between the
    # input patch shape and the output shape.
    output_shape = [patch_shape[0] // truth_downsample,
                    patch_shape[1] // truth_downsample,
                    1]
    padding = np.ceil(np.subtract(patch_shape, output_shape) / 2).astype(int)
    data_file.root.data = \
        [np.pad(data, [(_, _) for _ in padding], 'constant', constant_values=data_min)
         for data, data_min in zip(data_file.data, data_file.stats.min)]
    data_file.root.truth = \
        [np.pad(truth, [(_, _) for _ in padding], 'constant', constant_values=0)
         for truth in data_file.truth]
    # Second pass: make every volume strictly larger than patch_shape in each
    # dimension, so that a random patch corner can always be drawn.
    data_file.root.data = \
        [np.pad(data,
                [(_, _) for _ in np.ceil(np.maximum(np.subtract(patch_shape, data.shape) + 1, 0) / 2).astype(int)],
                'constant', constant_values=data_min)
         for data, data_min in zip(data_file.data, data_file.stats.min)]
    data_file.root.truth = \
        [np.pad(truth,
                [(_, _) for _ in np.ceil(np.maximum(np.subtract(patch_shape, truth.shape) + 1, 0) / 2).astype(int)],
                'constant', constant_values=0)
         for truth in data_file.truth]
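
def _example_padding_arithmetic():
    # Worked example of the first padding pass in pad_samples (numbers are
    # illustrative): with patch_shape=(96, 96, 5) and truth_downsample=2 the
    # output shape is (48, 48, 1), so every volume gets ceil((96-48)/2,
    # (96-48)/2, (5-1)/2) = (24, 24, 2) voxels of padding on each side.
    patch_shape = (96, 96, 5)
    output_shape = [patch_shape[0] // 2, patch_shape[1] // 2, 1]
    padding = np.ceil(np.subtract(patch_shape, output_shape) / 2).astype(int)
    assert list(padding) == [24, 24, 2]
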
def get_training_and_validation_generators(data_file, batch_size, n_labels, training_keys_file, validation_keys_file,
test_keys_file,
patch_shape=None, data_split=0.8, overwrite=False, labels=None, augment=None,
validation_batch_size=None, skip_blank_train=True, skip_blank_val=False,
truth_index=-1, truth_size=1, truth_downsample=None, truth_crop=True,
patches_per_epoch=1,
categorical=True, is3d=False,
prev_truth_index=None, prev_truth_size=None,
drop_easy_patches_train=False, drop_easy_patches_val=False,
samples_pad=3, val_augment=None):
"""
Creates the training and validation generators that can be used when training the model.
:param prev_truth_inedx:
:param categorical:
:param truth_downsample:
:param skip_blank: If True, any blank (all-zero) label images/patches will be skipped by the data generator.
:param validation_batch_size: Batch size for the validation data.
:param training_patch_start_offset: Tuple of length 3 containing integer values. Training data will randomly be
offset by a number of pixels between (0, 0, 0) and the given tuple. (default is None)
:param validation_patch_overlap: Number of pixels/voxels that will be overlapped in the validation data. (requires
patch_shape to not be None)
:param patch_shape: Shape of the data to return with the generator. If None, the whole image will be returned.
(default is None)
that the data will be distorted (in a stretching or shrinking fashion). Set to None, False, or 0 to prevent the
augmentation from distorting the data in this way.
:param augment: If not None, training data will be distorted on the fly so as to avoid over-fitting.
:param labels: List or tuple containing the ordered label values in the image files. The length of the list or tuple
should be equal to the n_labels value.
Example: (10, 25, 50)
The data generator would then return binary truth arrays representing the labels 10, 25, and 30 in that order.
:param data_file: hdf5 file to load the data from.
:param batch_size: Size of the batches that the training generator will provide.
:param n_labels: Number of binary labels.
:param training_keys_file: Pickle file where the index locations of the training data will be stored.
:param validation_keys_file: Pickle file where the index locations of the validation data will be stored.
:param data_split: How the training and validation data will be split. 0 means all the data will be used for
validation and none of it will be used for training. 1 means that all the data will be used for training and none
will be used for validation. Default is 0.8 or 80%.
:param overwrite: If set to True, previous files will be overwritten. The default mode is false, so that the
training and validation splits won't be overwritten when rerunning model training.
:return: Training data generator, validation data generator, number of training steps, number of validation steps
"""
if not validation_batch_size:
validation_batch_size = batch_size
    # Load the hdf5 contents into memory and pad so every volume can host a full patch
    data_file = DataFileDummy(data_file, samples_pad)
    pad_samples(data_file, patch_shape, truth_downsample or 1)
training_list, validation_list, test_list = get_validation_split(data_file,
data_split=data_split,
overwrite=overwrite,
training_file=training_keys_file,
validation_file=validation_keys_file,
test_file=test_keys_file)
print("Training: {}".format([data_file.subject_ids[_].decode() for _ in training_list]))
print("Validation: {}".format([data_file.subject_ids[_].decode() for _ in validation_list]))
print("Test: {}".format([data_file.subject_ids[_].decode() for _ in test_list]))
    # Derive the number of training and validation steps per epoch from patches_per_epoch
num_training_steps = patches_per_epoch // batch_size
print("Number of training steps: ", num_training_steps)
num_validation_steps = patches_per_epoch // validation_batch_size
print("Number of validation steps: ", num_validation_steps)
training_generator = \
data_generator(data_file=data_file, index_list=training_list, batch_size=batch_size,
augment=augment,
n_labels=n_labels, labels=labels, patch_shape=patch_shape,
skip_blank=skip_blank_train,
truth_index=truth_index, truth_size=truth_size,
truth_downsample=truth_downsample, truth_crop=truth_crop,
categorical=categorical, is3d=is3d,
prev_truth_index=prev_truth_index, prev_truth_size=prev_truth_size,
drop_easy_patches=drop_easy_patches_train)
validation_generator = \
data_generator(data_file=data_file, index_list=validation_list, batch_size=validation_batch_size,
augment=val_augment,
n_labels=n_labels, labels=labels, patch_shape=patch_shape,
skip_blank=skip_blank_val,
truth_index=truth_index, truth_size=truth_size,
truth_downsample=truth_downsample, truth_crop=truth_crop,
categorical=categorical, is3d=is3d,
prev_truth_index=prev_truth_index, prev_truth_size=prev_truth_size,
drop_easy_patches=drop_easy_patches_val)
return training_generator, validation_generator, num_training_steps, num_validation_steps
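
def _example_training_setup():
    # End-to-end usage sketch with hypothetical paths and illustrative
    # parameters (real values come from the experiment config). Each yielded
    # batch is (inputs, truth); inputs is [x, masks] when the file has masks.
    import tables
    hdf5 = tables.open_file('fetal_data.h5', 'r')
    train_gen, val_gen, n_train, n_val = get_training_and_validation_generators(
        hdf5, batch_size=4, n_labels=1,
        training_keys_file='training_ids.pkl',
        validation_keys_file='validation_ids.pkl',
        test_keys_file='test_ids.pkl',
        patch_shape=(96, 96, 5), patches_per_epoch=400)
    inputs, truth = next(train_gen)
    hdf5.close()
    return inputs, truth
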
def get_number_of_steps(n_samples, batch_size):
    # Note: when there are fewer samples than the batch size, this returns
    # n_samples (one step per sample) rather than a single step.
    if n_samples <= batch_size:
        return n_samples
    elif np.remainder(n_samples, batch_size) == 0:
        return n_samples // batch_size
    else:
        return n_samples // batch_size + 1
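
def _example_number_of_steps():
    # The three branches: fewer samples than a batch, an exact multiple,
    # and a remainder that needs one extra step.
    assert get_number_of_steps(3, 8) == 3
    assert get_number_of_steps(16, 8) == 2
    assert get_number_of_steps(17, 8) == 3
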
def get_validation_split(data_file, training_file, validation_file, test_file, data_split=0.8, overwrite=False):
    """
    Splits the data into training, validation and test index lists.
    :param data_file: pytables hdf5 data file
    :param training_file: pickle file in which the training indices are stored.
    :param validation_file: pickle file in which the validation indices are stored.
    :param test_file: pickle file in which the test indices are stored.
    :param data_split: fraction of the remaining (non-test) samples assigned to training.
    :param overwrite: if True, recompute the split even if it already exists on disk.
    :return: training, validation and test index lists.
    """
if overwrite or not os.path.exists(training_file):
print("Creating validation split...")
nb_samples = len(data_file.root.data)
sample_list = list(range(nb_samples))
random.shuffle(sample_list)
        test_list = [sample_list.pop()]  # hold out a single subject for testing
training_list, validation_list = split_list(sample_list, split=data_split)
pickle_dump(training_list, training_file)
pickle_dump(validation_list, validation_file)
pickle_dump(test_list, test_file)
return training_list, validation_list, test_list
else:
print("Loading previous validation split...")
return pickle_load(training_file), pickle_load(validation_file), pickle_load(test_file)
def split_list(input_list, split=0.8, shuffle_list=True):
if shuffle_list:
random.shuffle(input_list)
n_training = int(len(input_list) * split)
training = input_list[:n_training]
testing = input_list[n_training:]
return training, testing
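
def _example_split_list():
    # An 80/20 split of ten indices: the first 8 train, the last 2 validate
    # (shuffle disabled here to keep the example deterministic).
    training, validation = split_list(list(range(10)), split=0.8, shuffle_list=False)
    assert training == [0, 1, 2, 3, 4, 5, 6, 7]
    assert validation == [8, 9]
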
def random_list_generator(index_list):
    # Yields the indices in a fresh random order on every pass. np.random is
    # reseeded from OS entropy so that forked keras worker processes do not
    # all share the same numpy state for augmentation; random.sample itself
    # draws from the `random` module.
    while True:
        np.random.seed()
        yield from random.sample(index_list, len(index_list))


def list_generator(index_list):
    # Yields the indices in their original order, repeating forever.
    while True:
        yield from index_list
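
def _example_index_generators():
    # Both generators are infinite; islice takes a finite prefix. The random
    # generator yields a fresh permutation per pass, the plain one cycles in
    # order.
    from itertools import islice
    ordered = list(islice(list_generator([0, 1, 2]), 6))
    assert ordered == [0, 1, 2, 0, 1, 2]
    shuffled = list(islice(random_list_generator([0, 1, 2]), 6))
    assert sorted(shuffled[:3]) == [0, 1, 2] and sorted(shuffled[3:]) == [0, 1, 2]
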
class FetalSequence(Sequence):
    """Keras Sequence wrapper around data_generator; epoch_size is the number of steps per epoch."""

    def __init__(self, epoch_size, **kwargs):
        self.kwargs = kwargs
        self.generator = data_generator(**kwargs)
        self.epoch_size = epoch_size

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, idx):
        # the underlying generator already randomizes patches, so idx is unused
        return next(self.generator)

    def reset(self):
        self.generator = data_generator(**self.kwargs)
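
def _example_sequence(train_list, data_file):
    # Sketch of wrapping the generator in a keras Sequence (arguments are
    # illustrative; data_file is the in-memory DataFileDummy wrapper).
    seq = FetalSequence(epoch_size=100,
                        data_file=data_file, index_list=train_list,
                        batch_size=4, patch_shape=(96, 96, 5))
    inputs, truth = seq[0]  # each item is one freshly generated batch
    return inputs, truth
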
def data_generator(data_file, index_list, batch_size=1, n_labels=1, labels=None, augment=None, patch_shape=None,
shuffle_index_list=True, skip_blank=True, truth_index=-1, truth_size=1, truth_downsample=None,
truth_crop=True, categorical=True, prev_truth_index=None, prev_truth_size=None,
drop_easy_patches=False, is3d=False):
index_generator = random_list_generator(index_list) if shuffle_index_list else list_generator(index_list)
while True:
x_list = list()
y_list = list()
mask_list = list()
while len(x_list) < batch_size:
index = next(index_generator)
            # add_data may decline to append (blank truth / dropped easy patch),
            # so keep drawing indices until the batch is full
            add_data(x_list, y_list, mask_list, data_file, index, augment=augment,
                     patch_shape=patch_shape, skip_blank=skip_blank,
                     truth_index=truth_index, truth_size=truth_size, truth_downsample=truth_downsample,
                     truth_crop=truth_crop, prev_truth_index=prev_truth_index,
                     prev_truth_size=prev_truth_size, drop_easy_patches=drop_easy_patches)
yield convert_data(x_list, y_list, mask_list, n_labels=n_labels, labels=labels, categorical=categorical,
is3d=is3d)
def add_data(x_list, y_list, mask_list, data_file, index, truth_index, truth_size=1, augment=None, patch_shape=None,
skip_blank=True,
truth_downsample=None, truth_crop=True, prev_truth_index=None, prev_truth_size=None,
drop_easy_patches=False):
"""
Adds data from the data file to the given lists of feature and target data
:param prev_truth_index:
:param truth_downsample:
:param skip_blank: Data will not be added if the truth vector is all zeros (default is True).
:param patch_shape: Shape of the patch to add to the data lists. If None, the whole image will be added.
:param x_list: list of data to which data from the data_file will be appended.
:param y_list: list of data to which the target data from the data_file will be appended.
:param data_file: hdf5 data file.
:param index: index of the data file from which to extract the data.
:param augment: if not None, data will be augmented according to the augmentation parameters
:return:
"""
data, truth, mask = get_data_from_file(data_file, index, patch_shape=None)
    # Draw a random patch corner such that the patch fits inside the volume
    patch_corner = [
        np.random.randint(low=low, high=high)
        for low, high in zip((0, 0, 0), truth.shape - np.array(patch_shape))
    ]
if augment is not None:
data_range = [(start, start + size) for start, size in zip(patch_corner, patch_shape)]
truth_range = data_range[:2] + [(patch_corner[2] + truth_index,
patch_corner[2] + truth_index + truth_size)]
if prev_truth_index is not None:
prev_truth_range = data_range[:2] + [(patch_corner[2] + prev_truth_index,
patch_corner[2] + prev_truth_index + prev_truth_size)]
else:
prev_truth_range = None
data, truth, prev_truth, mask = \
augment_data(data, truth,
data_min=data_file.stats.min[index],
data_max=data_file.stats.max[index],
mask=mask,
scale_deviation=augment.get('scale', None),
iso_scale_deviation=augment.get('iso_scale', None),
rotate_deviation=augment.get('rotate', None),
translate_deviation=augment.get('translate', None),
flip=augment.get('flip', None),
contrast_deviation=augment.get('contrast', None),
piecewise_affine=augment.get('piecewise_affine', None),
elastic_transform=augment.get('elastic_transform', None),
intensity_multiplication_range=augment.get('intensity_multiplication', None),
poisson_noise=augment.get("poisson_noise", None),
gaussian_noise=augment.get("gaussian_noise", None),
speckle_noise=augment.get("speckle_noise", None),
gaussian_filter=augment.get("gaussian_filter", None),
coarse_dropout=augment.get("coarse_dropout", None),
data_range=data_range, truth_range=truth_range,
prev_truth_range=prev_truth_range)
else:
data, truth, prev_truth, mask = \
extract_patch(data, patch_corner, patch_shape, truth, mask,
truth_index=truth_index, truth_size=truth_size,
prev_truth_index=prev_truth_index, prev_truth_size=prev_truth_size)
if prev_truth is not None:
data = np.concatenate([data, prev_truth], axis=-1)
    if drop_easy_patches:
        # Drop patches whose central region is almost all foreground or all
        # background; balanced ("hard") patches are always kept
        truth_mean = np.mean(truth[16:-16, 16:-16, :])
        if 1 - np.abs(truth_mean - 0.5) < np.random.random():
            return
if truth_downsample is not None and truth_downsample > 1:
truth_shape = patch_shape[:-1] + (1,)
new_shape = np.array(truth_shape)
new_shape[:-1] = new_shape[:-1] // truth_downsample
if truth_crop:
truth = get_patch_from_3d_data(truth,
new_shape,
list(np.subtract(truth_shape[:2], new_shape[:2]) // 2) + [1])
else:
truth = resize(get_image(truth), new_shape=new_shape).get_data()
if not skip_blank or np.any(truth != 0):
x_list.append(data)
y_list.append(truth)
if mask is not None:
mask_list.append(mask)
def extract_patch(data, patch_corner, patch_shape, truth, mask, truth_index, truth_size, prev_truth_index=None,
prev_truth_size=1):
data = get_patch_from_3d_data(data, patch_shape, patch_corner)
real_truth = get_patch_from_3d_data(truth,
patch_shape[:-1] + (truth_size,),
patch_corner + np.array((0, 0, truth_index)))
if mask is not None:
mask = get_patch_from_3d_data(mask,
patch_shape[:-1] + (truth_size,),
patch_corner + np.array((0, 0, truth_index)))
if prev_truth_index is not None:
prev_truth = get_patch_from_3d_data(truth,
patch_shape[:-1] + (prev_truth_size,),
patch_corner + np.array((0, 0, prev_truth_index)))
else:
prev_truth = None
return data, real_truth, prev_truth, mask
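
def _example_extract_patch():
    # Geometry sketch (illustrative shapes, assuming get_patch_from_3d_data
    # returns exactly the requested window): a (32, 32, 5) data patch is cut
    # at a fixed corner, and the truth slab is the single slice truth_index
    # slices into the patch along z.
    data = np.random.rand(64, 64, 16)
    truth = (np.random.rand(64, 64, 16) > 0.5).astype(np.uint8)
    patch, truth_slab, prev_truth, mask = extract_patch(
        data, patch_corner=(10, 10, 4), patch_shape=(32, 32, 5),
        truth=truth, mask=None, truth_index=2, truth_size=1)
    assert patch.shape == (32, 32, 5)
    assert truth_slab.shape == (32, 32, 1)
    assert prev_truth is None and mask is None
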
def extract_random_patch(data, patch_shape, truth, mask, truth_index, prev_truth_index):
    # cut a patch at a random corner that keeps the patch inside the volume
    patch_corner = [
        np.random.randint(low=low, high=high)
        for low, high in zip((0, 0, 0), truth.shape - np.array(patch_shape))
    ]
    # pass prev_truth_index by keyword so it is not consumed as truth_size
    return extract_patch(data, patch_corner, patch_shape, truth, mask,
                         truth_index=truth_index, truth_size=1,
                         prev_truth_index=prev_truth_index)
def get_data_from_file(data_file, index, patch_shape=None):
    if patch_shape:
        # when a patch is requested, index is a (sample_index, patch_corner) pair
        index, patch_index = index
data, truth, mask = get_data_from_file(data_file, index, patch_shape=None)
x = get_patch_from_3d_data(data, patch_shape, patch_index)
y = get_patch_from_3d_data(truth, patch_shape, patch_index)
if mask is not None:
z = get_patch_from_3d_data(mask, patch_shape, patch_index)
else:
z = None
else:
if data_file.root.mask is not None:
z = data_file.root.mask[index]
else:
z = None
x, y = data_file.root.data[index], data_file.root.truth[index]
return x, y, z
def convert_data(x_list, y_list, mask_list, n_labels=1, labels=None, categorical=True, is3d=False):
    x = np.asarray(x_list)
    y = np.asarray(y_list)
    masks = np.asarray(mask_list)
    # if n_labels == 1:
    #     y[y > 0] = 1
    # elif n_labels > 1:
    #     y = get_multi_class_labels(y, n_labels=n_labels, labels=labels)
    if categorical:
        y = to_categorical(y, 2)
    if is3d:
        # add a channel axis for 3D models
        x = np.expand_dims(x, 1)
        y = np.expand_dims(y, 1)
        masks = np.expand_dims(masks, 1)
    inputs = x
    if len(masks) > 0:
        inputs = [x, masks]
    return inputs, y
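
def _example_convert_data():
    # Toy batch: two 4x4 patches with 3 slices each; truth patches are a
    # single slice. to_categorical appends a 2-class one-hot axis (the exact
    # truth shape depends on how the keras version treats trailing size-1 axes).
    x_list = [np.zeros((4, 4, 3)), np.ones((4, 4, 3))]
    y_list = [np.zeros((4, 4, 1)), np.ones((4, 4, 1))]
    inputs, y = convert_data(x_list, y_list, mask_list=[], n_labels=1)
    assert inputs.shape == (2, 4, 4, 3)
    assert y.shape[-1] == 2  # one-hot class axis
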
def get_multi_class_labels(data, n_labels, labels=None):
"""
Translates a label map into a set of binary labels.
:param data: numpy array containing the label map with shape: (n_samples, 1, ...).
:param n_labels: number of labels.
:param labels: integer values of the labels.
:return: binary numpy array of shape: (n_samples, n_labels, ...)
"""
new_shape = [data.shape[0], n_labels] + list(data.shape[2:])
y = np.zeros(new_shape, np.int8)
for label_index in range(n_labels):
if labels is not None:
y[:, label_index][data[:, 0] == labels[label_index]] = 1
else:
y[:, label_index][data[:, 0] == (label_index + 1)] = 1
return y
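
def _example_multi_class_labels():
    # A label map with values (10, 25, 50) becomes three binary channels.
    data = np.array([[[10, 25], [50, 0]]])[:, np.newaxis]  # (n_samples=1, 1, 2, 2)
    y = get_multi_class_labels(data, n_labels=3, labels=(10, 25, 50))
    assert y.shape == (1, 3, 2, 2)
    assert y[0, 0, 0, 0] == 1  # the voxel labelled 10 lights up channel 0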