--- a +++ b/ext/lab2im/layers.py @@ -0,0 +1,2060 @@ +""" +This file regroups several custom keras layers used in the generation model: + - RandomSpatialDeformation, + - RandomCrop, + - RandomFlip, + - SampleConditionalGMM, + - SampleResolution, + - GaussianBlur, + - DynamicGaussianBlur, + - MimicAcquisition, + - BiasFieldCorruption, + - IntensityAugmentation, + - DiceLoss, + - WeightedL2Loss, + - ResetValuesToZero, + - ConvertLabels, + - PadAroundCentre, + - MaskEdges + - ImageGradients + - RandomDilationErosion + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import keras +import numpy as np +import tensorflow as tf +import keras.backend as K +from keras.layers import Layer + +# project imports +from ext.lab2im import utils +from ext.lab2im import edit_tensors as l2i_et + +# third-party imports +from ext.neuron import utils as nrn_utils +import ext.neuron.layers as nrn_layers + + +class RandomSpatialDeformation(Layer): + """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations. + The input tensors are expected to have the same shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + The non-linear deformation is obtained by: + 1) a small-size SVF is sampled from a centred normal distribution of random standard deviation. + 2) it is resized with trilinear interpolation to half the shape of the input tensor + 3) it is integrated to obtain a diffeomorphic transformation + 4) finally, it is resized (again with trilinear interpolation) to full image size + :param scaling_bounds: (optional) range of the random scaling to apply. The scaling factor for each dimension is + sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.15 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}. + This is done regardless of the value of rotation_bounds. If true, a different value is sampled for each dimension. + :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we + sample the small-size SVF. Set to 0 if you wish to completely turn the elastic deformation off. + :param nonlin_scale: (optional) if nonlin_std is not False, factor between the shapes of the input tensor + and the shape of the input non-linear tensor. + :param inter_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest' + :param prob_deform: (optional) probability to apply spatial deformation + """ + + def __init__(self, + scaling_bounds=0.15, + rotation_bounds=10, + shearing_bounds=0.02, + translation_bounds=False, + enable_90_rotations=False, + nonlin_std=4., + nonlin_scale=.0625, + inter_method='linear', + prob_deform=1, + **kwargs): + + # shape attributes + self.n_inputs = 1 + self.inshape = None + self.n_dims = None + self.small_shape = None + + # deformation attributes + self.scaling_bounds = scaling_bounds + self.rotation_bounds = rotation_bounds + self.shearing_bounds = shearing_bounds + self.translation_bounds = translation_bounds + self.enable_90_rotations = enable_90_rotations + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + + # boolean attributes + self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \ + (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \ + self.enable_90_rotations + self.apply_elastic_trans = self.nonlin_std > 0 + self.prob_deform = prob_deform + + # interpolation methods + self.inter_method = inter_method + + super(RandomSpatialDeformation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["scaling_bounds"] = self.scaling_bounds + config["rotation_bounds"] = self.rotation_bounds + config["shearing_bounds"] = self.shearing_bounds + config["translation_bounds"] = self.translation_bounds + config["enable_90_rotations"] = self.enable_90_rotations + config["nonlin_std"] = self.nonlin_std + config["nonlin_scale"] = self.nonlin_scale + config["inter_method"] = self.inter_method + config["prob_deform"] = self.prob_deform + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + inputshape = [input_shape] + else: + self.n_inputs = len(input_shape) + inputshape = input_shape + self.inshape = inputshape[0][1:] + self.n_dims = len(self.inshape) - 1 + + if self.apply_elastic_trans: + self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims], + self.nonlin_scale, self.n_dims) + else: + self.small_shape = None + + self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str') + + self.built = True + super(RandomSpatialDeformation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat inputs and get its shape + if self.n_inputs < 2: + inputs = [inputs] + types = [v.dtype for v in inputs] + inputs = [tf.cast(v, dtype='float32') for v in inputs] + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + + # initialise list of transforms to operate + list_trans = list() + + # add affine deformation to inputs list + if self.apply_affine_trans: + affine_trans = utils.sample_affine_transform(batchsize, + self.n_dims, + self.rotation_bounds, + self.scaling_bounds, + self.shearing_bounds, + self.translation_bounds, + self.enable_90_rotations) + list_trans.append(affine_trans) + + # prepare non-linear deformation field and add it to inputs list + if self.apply_elastic_trans: + + # sample small field from normal distribution of specified std dev + trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0) + trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std) + elastic_trans = tf.random.normal(trans_shape, stddev=trans_std) + + # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size + resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)] + elastic_trans = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans) + elastic_trans = nrn_layers.VecInt()(elastic_trans) + elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans) + list_trans.append(elastic_trans) + + # apply deformations and return tensors with correct dtype + if self.apply_affine_trans | self.apply_elastic_trans: + if self.prob_deform == 1: + inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in + zip(self.inter_method, inputs)] + else: + rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform)) + inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v) + for (m, v) in zip(self.inter_method, inputs)] + if self.n_inputs < 2: + return tf.cast(inputs[0], types[0]) + else: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + +class RandomCrop(Layer): + """Randomly crop all input tensors to a given shape. This cropping is applied to all channels. + The input tensors are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param crop_shape: list with cropping shape in each dimension (excluding batch and channel dimension) + + example: + if input is a tensor of shape [batchsize, 160, 160, 160, 3], + output = RandomCrop(crop_shape=[96, 128, 96])(input) + will yield an output of shape [batchsize, 96, 128, 96, 3] that is obtained by cropping with randomly selected + cropping indices. + """ + + def __init__(self, crop_shape, **kwargs): + + self.several_inputs = True + self.crop_max_val = None + self.crop_shape = crop_shape + self.n_dims = len(crop_shape) + self.list_n_channels = None + super(RandomCrop, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["crop_shape"] = self.crop_shape + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape) + self.list_n_channels = [i[-1] for i in inputshape] + self.built = True + super(RandomCrop, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # if one input only is provided, performs the cropping directly + if not self.several_inputs: + return tf.map_fn(self._single_slice, inputs, dtype=inputs.dtype) + + # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location + else: + types = [v.dtype for v in inputs] + inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1) + inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + def _single_slice(self, vol): + crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32') + crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0) + crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32') + return tf.slice(vol, begin=crop_idx, size=crop_size) + + def compute_output_shape(self, input_shape): + output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels] + return output_shape if self.several_inputs else output_shape[0] + + +class RandomFlip(Layer): + """This layer randomly flips the input tensor along the specified axes with a specified probability. + It can also take multiple tensors as inputs (if they have the same shape). The same flips will be applied to all + input tensors. These are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + If specified, this layer can also swap corresponding values. This is especially useful when flipping label maps + with different labels for right/left structures, such that the flipped label maps keep a consistent labelling. + :param axis: integer, or list of integers specifying the dimensions along which to flip. + If a list, the input tensors can be flipped simultaneously in several directions. The values in flip_axis exclude + the batch dimension (e.g. 0 will flip the tensor along the first axis after the batch dimension). + Default is None, where the tensors can be flipped along all axes (except batch and channel axes). + :param swap_labels: boolean to specify whether to swap the values of each input. Values are only swapped if an odd + number of flips is applied. + Can also be a list if several tensors are given as input. + All the inputs for which the values need to be swapped must be int32 or int64. + :param label_list: if swap_labels is True, list of all labels contained in labels. Must be ordered as follows, first + the neutral labels (i.e. non-sided), then left labels and right labels. + :param n_neutral_labels: if swap_labels is True, number of non-sided labels + :param prob: probability to flip along each specified axis + + example 1: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip()(input) will randomly flip input along one of the 1st, 2nd, or 3rd axis (i.e. those with shape + 10, 100, 200). + + example 2: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip(flip_axis=1)(input) will randomly flip input along the 3rd axis (with shape 100), i.e. the axis + with index 1 if we don't count the batch axis. + + example 3: + input = tf.convert_to_tensor(np.array([[1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 0, 0, 0]])) + label_list = np.array([0, 1, 2]) + n_neutral_labels = 1 + output = RandomFlip(flip_axis=1, swap_labels=True, label_list=label_list, n_neutral_labels=n_neutral_labels)(input) + where output will either be equal to input (bear in mind the flipping occurs with a 0.5 probability), or: + output = [[0, 0, 0, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 0, 0, 0, 0, 0, 2]] + Note that the input must have a dtype int32 or int64 for its values to be swapped, otherwise an error will be raised + + example 4: + if labels is the same as in the input of example 3, and image is a float32 image, then we can swap consistently both + the labels and the image with: + labels, image = RandomFlip(flip_axis=1, swap_labels=[True, False], label_list=label_list, + n_neutral_labels=n_neutral_labels)([labels, image]]) + Note that the labels must have a dtype int32 or int64 to be swapped, otherwise an error will be raised. + This doesn't concern the image input, as its values are not swapped. + """ + + def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs): + + # shape attributes + self.several_inputs = True + self.n_dims = None + self.list_n_channels = None + + # axis along which to flip + self.axis = utils.reformat_to_list(axis) + self.flip_axes = None + + # whether to swap labels, and corresponding label list + self.swap_labels = utils.reformat_to_list(swap_labels) + self.label_list = label_list + self.n_neutral_labels = n_neutral_labels + self.swap_lut = None + + self.prob = prob + + super(RandomFlip, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axis"] = self.axis + config["swap_labels"] = self.swap_labels + config["label_list"] = self.label_list + config["n_neutral_labels"] = self.n_neutral_labels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.n_dims = len(inputshape[0][1:-1]) + self.list_n_channels = [i[-1] for i in inputshape] + self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape)) + self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis + + # create label list with swapped labels + if any(self.swap_labels): + assert (self.label_list is not None) & (self.n_neutral_labels is not None), \ + 'please provide a label_list, and n_neutral_labels when swapping the values of at least one input' + n_labels = len(self.label_list) + if self.n_neutral_labels == n_labels: + self.swap_labels = [False] * len(self.swap_labels) + else: + rl_split = np.split(self.label_list, [self.n_neutral_labels, + self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)]) + label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1])) + swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap) + self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32') + + self.built = True + super(RandomFlip, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # convert inputs to list, and get each input type + inputs = [inputs] if not self.several_inputs else inputs + types = [v.dtype for v in inputs] + + # store whether to flip along each specified dimension + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0) + rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob) + + # swap right/left labels if we apply an odd number of flips + odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0 + swapped_inputs = list() + for i in range(len(inputs)): + if self.swap_labels[i]: + swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i])) + else: + swapped_inputs.append(inputs[i]) + + # flip inputs and convert them back to their original type + inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1) + inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + + if self.several_inputs: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + else: + return tf.cast(inputs[0], types[0]) + + def _single_swap(self, inputs): + return K.switch(inputs[1], tf.gather(self.swap_lut, inputs[0]), inputs[0]) + + @staticmethod + def _single_flip(inputs): + flip_axis = tf.where(inputs[1]) + return K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0])) + + +class SampleConditionalGMM(Layer): + """This layer generates an image by sampling a Gaussian Mixture Model conditioned on a label map given as input. + The parameters of the GMM are given as two additional inputs to the layer (means and standard deviations): + image = SampleConditionalGMM(generation_labels)([label_map, means, stds]) + + :param generation_labels: list of all possible label values contained in the input label maps. + Must be a list or a 1D numpy array of size N, where N is the total number of possible label values. + + Layer inputs: + label_map: input label map of shape [batchsize, shape_dim1, ..., shape_dimn, n_channel]. + All the values of label_map must be contained in generation_labels, but the input label_map doesn't necessarily have + to contain all the values in generation_labels. + means: tensor containing the mean values of all Gaussian distributions of the GMM. + It must be of shape [batchsize, N, n_channel], and in the same order as generation label, + i.e. the ith value of generation_labels will be associated to the ith value of means. + stds: same as means but for the standard deviations of the GMM. + """ + + def __init__(self, generation_labels, **kwargs): + self.generation_labels = generation_labels + self.n_labels = None + self.n_channels = None + self.max_label = None + self.indices = None + self.shape = None + super(SampleConditionalGMM, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["generation_labels"] = self.generation_labels + return config + + def build(self, input_shape): + + # check n_labels and n_channels + assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).' + self.n_channels = input_shape[1][-1] + self.n_labels = len(self.generation_labels) + assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels' + assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels' + + # scatter parameters (to build mean/std lut) + self.max_label = np.max(self.generation_labels) + 1 + indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1) + self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32') + self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32') + + self.built = True + super(SampleConditionalGMM, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat labels and scatter indices + batch = tf.split(tf.shape(inputs[0]), [1, -1])[0] + tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0)) + labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1) + + # build mean map + means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1) + tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0) + means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape) + means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32) + + # same for stds + stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1) + stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape) + stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32) + + return stds_map * tf.random.normal(tf.shape(labels)) + means_map + + def compute_output_shape(self, input_shape): + return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels]) + + +class SampleResolution(Layer): + """Build a random resolution tensor by sampling a uniform distribution of provided range. + + You can use this layer in the following ways: + resolution = SampleConditionalGMM(min_resolution)() in this case resolution will be a tensor of shape (n_dims,), + where n_dims is the length of the min_resolution parameter (provided as a list, see below). + resolution = SampleConditionalGMM(min_resolution)(input), where input is a tensor for which the first dimension + represents the batch_size. In this case resolution will be a tensor of shape (batchsize, n_dims,). + + :param min_resolution: list of length n_dims specifying the inferior bounds of the uniform distributions to + sample from for each value. + :param max_res_iso: If not None, all the values of resolution will be equal to the same value, which is randomly + sampled at each minibatch in U(min_resolution, max_res_iso). + :param max_res_aniso: If not None, we first randomly select a direction i in the range [0, n_dims-1], and we sample + a value in the corresponding uniform distribution U(min_resolution[i], max_res_aniso[i]). + The other values of resolution will be set to min_resolution. + :param prob_iso: if both max_res_iso and max_res_aniso are specified, this allows to specify the probability of + sampling an isotropic resolution (therefore using max_res_iso) with respect to anisotropic resolution + (which would use max_res_aniso). + :param prob_min: if not zero, this allows to return with the specified probability an output resolution equal + to min_resolution. + :param return_thickness: if set to True, this layer will also return a thickness value of the same shape as + resolution, which will be sampled independently for each axis from the uniform distribution + U(min_resolution, resolution). + + """ + + def __init__(self, + min_resolution, + max_res_iso=None, + max_res_aniso=None, + prob_iso=0.1, + prob_min=0.05, + return_thickness=True, + **kwargs): + + self.min_res = min_resolution + self.max_res_iso_input = max_res_iso + self.max_res_iso = None + self.max_res_aniso_input = max_res_aniso + self.max_res_aniso = None + self.prob_iso = prob_iso + self.prob_min = prob_min + self.return_thickness = return_thickness + self.n_dims = len(self.min_res) + self.add_batchsize = False + self.min_res_tens = None + super(SampleResolution, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_resolution"] = self.min_res + config["max_res_iso"] = self.max_res_iso + config["max_res_aniso"] = self.max_res_aniso + config["prob_iso"] = self.prob_iso + config["prob_min"] = self.prob_min + config["return_thickness"] = self.return_thickness + return config + + def build(self, input_shape): + + # check maximum resolutions + assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \ + 'at least one of maximum isotropic or anisotropic resolutions must be provided, received none' + + # reformat resolutions as numpy arrays + self.min_res = np.array(self.min_res) + if self.max_res_iso_input is not None: + self.max_res_iso = np.array(self.max_res_iso_input) + assert len(self.min_res) == len(self.max_res_iso), \ + 'min and isotropic max resolution must have the same length, ' \ + 'had {0} and {1}'.format(self.min_res, self.max_res_iso) + if np.array_equal(self.min_res, self.max_res_iso): + self.max_res_iso = None + if self.max_res_aniso_input is not None: + self.max_res_aniso = np.array(self.max_res_aniso_input) + assert len(self.min_res) == len(self.max_res_aniso), \ + 'min and anisotropic max resolution must have the same length, ' \ + 'had {} and {}'.format(self.min_res, self.max_res_aniso) + if np.array_equal(self.min_res, self.max_res_aniso): + self.max_res_aniso = None + + # check prob iso + if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0): + raise Exception('prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled') + + if input_shape: + self.add_batchsize = True + + self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32') + + self.built = True + super(SampleResolution, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.add_batchsize: + shape = [self.n_dims] + dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32') + mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim, + tf.convert_to_tensor([True], dtype='bool')) + else: + batch = tf.split(tf.shape(inputs), [1, -1])[0] + tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0) + self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) + + shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0) + indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1) + mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool')) + + # return min resolution as tensor if min=max + if (self.max_res_iso is None) & (self.max_res_aniso is None): + new_resolution = self.min_res_tens + + # sample isotropic resolution only + elif (self.max_res_iso is not None) & (self.max_res_aniso is None): + new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution_iso) + + # sample anisotropic resolution only + elif (self.max_res_iso is None) & (self.max_res_aniso is not None): + new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + tf.where(mask, new_resolution_aniso, self.min_res_tens)) + + # sample either anisotropic or isotropic resolution + else: + new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) + new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), + new_resolution_iso, + tf.where(mask, new_resolution_aniso, self.min_res_tens)) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution) + + if self.return_thickness: + return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)] + else: + return new_resolution + + def compute_output_shape(self, input_shape): + if self.return_thickness: + return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2 + else: + return (None, self.n_dims) if self.add_batchsize else self.n_dims + + +class GaussianBlur(Layer): + """Applies gaussian blur to an input image. + The input image is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param sigma: standard deviation of the blurring kernels to apply. Can be a number, a list of length n_dims, or + a numpy array. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + :param use_mask: (optional) whether a mask of the input will be provided as an additional layer input. This is used + to mask the blurred image, and to correct for edge blurring effects. + + example 1: + output = GaussianBlur(sigma=0.5)(input) will isotropically blur the input with a gaussian kernel of std 0.5. + + example 2: + if input is a tensor of shape [batchsize, 10, 100, 200, 2] + output = GaussianBlur(sigma=[0.5, 1, 10])(input) will blur the input a different gaussian kernel in each dimension. + + example 3: + output = GaussianBlur(sigma=0.5, random_blur_range=1.15)(input) + will blur the input a different gaussian kernel in each dimension, as each dimension will be associated with + a kernel, whose standard deviation will be uniformly sampled from [0.5/1.15; 0.5*1.15]. + + example 4: + output = GaussianBlur(sigma=0.5, use_mask=True)([input, mask]) + will 1) blur the input a different gaussian kernel in each dimension, 2) mask the blurred image with the provided + mask, and 3) correct for edge blurring effects. If the provided mask is not of boolean type, it will be thresholded + above positive values. + """ + + def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs): + self.sigma = utils.reformat_to_list(sigma) + assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0' + self.use_mask = use_mask + + self.n_dims = None + self.n_channels = None + self.blur_range = random_blur_range + self.stride = None + self.separable = None + self.kernels = None + self.convnd = None + super(GaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["sigma"] = self.sigma + config["random_blur_range"] = self.blur_range + config["use_mask"] = self.use_mask + return config + + def build(self, input_shape): + + # get shapes + if self.use_mask: + assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True' + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + else: + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + + # prepare blurring kernel + self.stride = [1] * (self.n_dims + 2) + self.sigma = utils.reformat_to_list(self.sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.sigma)) > 5 + if self.blur_range is None: # fixed kernels + self.kernels = l2i_et.gaussian_kernel(self.sigma, separable=self.separable) + else: + self.kernels = None + + # prepare convolution + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + + self.built = True + super(GaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if self.use_mask: + image = inputs[0] + mask = tf.cast(inputs[1], 'bool') + else: + image = inputs + mask = None + + # redefine the kernels at each new step when blur_range is activated + if self.blur_range is not None: + self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable) + + if self.separable: + for k in self.kernels: + if k is not None: + image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + if self.use_mask: + maskb = tf.cast(mask, 'float32') + maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + else: + if any(self.sigma): + image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + if self.use_mask: + maskb = tf.cast(mask, 'float32') + maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + + return image + + +class DynamicGaussianBlur(Layer): + """Applies gaussian blur to an input image, where the standard deviation of the blurring kernel is provided as a + layer input, which enables to perform dynamic blurring (i.e. the blurring kernel can vary at each minibatch). + :param max_sigma: maximum value of the standard deviation that will be provided as input. This is used to compute + the size of the blurring kernels. It must be provided as a list of length n_dims. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + + example: + blurred_image = DynamicGaussianBlur(max_sigma=[5.]*3, random_blurring_range=1.15)([image, sigma]) + will return a blurred version of image, where the standard deviation of each dimension (given as a tensor, and with + values lower than 5 for each axis) is multiplied by a random coefficient uniformly sampled from [1/1.15; 1.15]. + """ + + def __init__(self, max_sigma, random_blur_range=None, **kwargs): + self.max_sigma = max_sigma + self.n_dims = None + self.n_channels = None + self.convnd = None + self.blur_range = random_blur_range + self.separable = None + super(DynamicGaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["max_sigma"] = self.max_sigma + config["random_blur_range"] = self.blur_range + return config + + def build(self, input_shape): + assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring' + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5 + self.built = True + super(DynamicGaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + image = inputs[0] + sigma = inputs[-1] + kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable) + if self.separable: + for kernel in kernels: + image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32) + else: + image = tf.map_fn(self._single_blur, [image, kernels], dtype=tf.float32) + return image + + def _single_blur(self, inputs): + if self.n_channels > 1: + split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1) + blurred_channel = list() + for channel in split_channels: + blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + blurred_channel.append(tf.squeeze(blurred, axis=0)) + output = tf.concat(blurred_channel, -1) + else: + output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + output = tf.squeeze(output, axis=0) + return output + + +class MimicAcquisition(Layer): + """ + Layer that takes an image as input, and simulates data that has been acquired at low resolution. + The output is obtained by resampling the input twice: + - first at a resolution given as an input (i.e. the "acquisition" resolution), + - then at the output resolution (specified output shape). + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param volume_res: resolution of the provided inputs. Must be a 1-D numpy array with n_dims elements. + :param min_subsample_res: lower bound of the acquisition resolutions to mimic (i.e. the input resolution must have + values higher than min-subsample_res). + :param resample_shape: shape of the output tensor + :param build_dist_map: whether to return distance maps as outputs. These indicate the distance between each voxel + and the nearest non-interpolated voxel (during the second resampling). + :param prob_noise: probability to apply noise injection + + example 1: + im_res = [1., 1., 1.] + low_res = [1., 1., 3.] + res = tf.convert_to_tensor([1., 1., 4.5]) + image is a tensor of shape (None, 256, 256, 256, 3) + resample_shape = [256, 256, 256] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 256, 256, 256, 3), obtained by downsampling image to [1., 1., 4.5]. + and re-upsampling it at initial resolution (because resample_shape is equal to the input shape). In this example all + examples of the batch will be downsampled to the same resolution (because res has no batch dimension). + Note that the provided res must have higher values than min_low_res. + + example 2: + im_res = [1., 1., 1.] + min_low_res = [1., 1., 1.] + res is a tensor of shape (None, 3), obtained for example by using the SampleResolution layer (see above). + image is a tensor of shape (None, 256, 256, 256, 1) + resample_shape = [128, 128, 128] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 128, 128, 128, 1), obtained by downsampling each examples of the batch to + the matching resolution in res, and resampling them all to half the initial resolution. + Note that the provided res must have higher values than min_low_res. + """ + + def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False, + noise_std=0, prob_noise=0.95, **kwargs): + + # resolutions and dimensions + self.volume_res = volume_res + self.min_subsample_res = min_subsample_res + self.n_dims = len(self.volume_res) + self.n_channels = None + self.add_batchsize = None + + # noise + self.noise_std = noise_std + self.prob_noise = prob_noise + + # input and output shapes + self.inshape = None + self.resample_shape = resample_shape + + # meshgrids for resampling + self.down_grid = None + self.up_grid = None + + # whether to return a map indicating the distance from the interpolated voxels, to acquired ones. + self.build_dist_map = build_dist_map + + super(MimicAcquisition, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["volume_res"] = self.volume_res + config["min_subsample_res"] = self.min_subsample_res + config["resample_shape"] = self.resample_shape + config["build_dist_map"] = self.build_dist_map + config["noise_std"] = self.noise_std + config["prob_noise"] = self.prob_noise + return config + + def build(self, input_shape): + + # set up input shape and acquisition shape + self.inshape = input_shape[0][1:] + self.n_channels = input_shape[0][-1] + self.add_batchsize = False if (input_shape[1][0] is None) else True + down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res) + + # build interpolation meshgrids + self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0) + self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0) + + self.built = True + super(MimicAcquisition, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sort inputs + assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution' + vol = inputs[0] + subsample_res = tf.cast(inputs[1], dtype='float32') + vol = K.reshape(vol, [-1, *self.inshape]) # necessary for multi_gpu models + batchsize = tf.split(tf.shape(vol), [1, -1])[0] + tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0) + + # get downsampling and upsampling factors + if self.add_batchsize: + subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape) + down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') / + subsample_res, dtype='int32') + down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32') + up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32') + + # downsample + down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0)) + down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims) + inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape) + inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims) + down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32')) + vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32) + + # add noise with predefined probability + if self.noise_std > 0: + sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'), + self.n_channels * tf.ones([1], dtype='int32')], 0) + noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std)) + if self.prob_noise == 1: + vol += noise + else: + vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol) + + # upsample + up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0)) + up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims) + vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32) + + # return upsampled volume + if not self.build_dist_map: + return vol + + # return upsampled volumes with distance maps + else: + + # get grid points + floor = tf.math.floor(up_loc) + ceil = tf.math.ceil(up_loc) + + # get distances of every voxel to higher and lower grid points for every dimension + f_dist = up_loc - floor + c_dist = ceil - up_loc + + # keep minimum 1d distances, and compute 3d distance to nearest grid point + dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims) + dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True)) + + return [vol, dist] + + @staticmethod + def _single_down_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest') + + @staticmethod + def _single_up_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear') + + def compute_output_shape(self, input_shape): + output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]]) + return [output_shape] * 2 if self.build_dist_map else output_shape + + +class BiasFieldCorruption(Layer): + """This layer applies a smooth random bias field to the input by applying the following steps: + 1) we first sample a value for the standard deviation of a centred normal distribution + 2) a small-size SVF is sampled from this normal distribution + 3) the small SVF is then resized with trilinear interpolation to image size + 4) it is rescaled to positive values by taking the voxel-wise exponential + 5) it is multiplied to the input tensor. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param bias_field_std: maximum value of the standard deviation sampled in 1 (it will be sampled from the range + [0, bias_field_std]) + :param bias_scale: ratio between the shape of the input tensor and the shape of the sampled SVF. + :param same_bias_for_all_channels: whether to apply the same bias field to all the channels of the input tensor. + :param prob: probability to apply this bias field corruption. + """ + + def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs): + + # input shape + self.several_inputs = False + self.inshape = None + self.n_dims = None + self.n_channels = None + + # sampling shape + self.std_shape = None + self.small_bias_shape = None + + # bias field parameters + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + self.same_bias_for_all_channels = same_bias_for_all_channels + self.prob = prob + + super(BiasFieldCorruption, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["bias_field_std"] = self.bias_field_std + config["bias_scale"] = self.bias_scale + config["same_bias_for_all_channels"] = self.same_bias_for_all_channels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + # input shape + if isinstance(input_shape, list): + self.several_inputs = True + self.inshape = input_shape + else: + self.inshape = [input_shape] + self.n_dims = len(self.inshape[0]) - 2 + self.n_channels = self.inshape[0][-1] + + # sampling shapes + self.std_shape = [1] * (self.n_dims + 1) + self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1) + if not self.same_bias_for_all_channels: + self.std_shape[-1] = self.n_channels + self.small_bias_shape[-1] = self.n_channels + + self.built = True + super(BiasFieldCorruption, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.several_inputs: + inputs = [inputs] + + if self.bias_field_std > 0: + + # sampling shapes + batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0] + std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0) + bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0) + + # sample small bias field + bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std)) + + # resize bias field and take exponential + bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field) + bias_field = tf.math.exp(bias_field) + + # apply bias field with predefined probability + if self.prob == 1: + return [tf.math.multiply(bias_field, v) for v in inputs] + else: + rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob)) + if self.several_inputs: + return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs] + else: + return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0]) + + else: + return inputs + + +class IntensityAugmentation(Layer): + """This layer enables to augment the intensities of the input tensor, as well as to apply min_max normalisation. + The following steps are applied (all are optional): + 1) white noise corruption, with a randomly sampled std dev. + 2) clip the input between two values + 3) min-max normalisation + 4) gamma augmentation (i.e. voxel-wise exponentiation by a randomly sampled power) + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param noise_std: maximum value of the standard deviation of the Gaussian white noise used in 1 (it will be sampled + from the range [0, noise_std]). Set to 0 to skip this step. + :param clip: clip the input tensor between the given values. Can either be: a number (in which case we clip between + 0 and the given value), or a list or a numpy array with two elements. Default is 0, where no clipping occurs. + :param normalise: whether to apply min-max normalisation, to normalise between 0 and 1. Default is True. + :param norm_perc: percentiles (between 0 and 1) of the sorted intensity values for robust normalisation. Can be: + a number (in which case the robust minimum is the provided percentile of sorted values, and the maximum is the + 1 - norm_perc percentile), or a list/numpy array of 2 elements (percentiles for the minimum and maximum values). + The minimum and maximum values are computed separately for each channel if separate_channels is True. + Default is 0, where we simply take the minimum and maximum values. + :param gamma_std: standard deviation of the normal distribution from which we sample gamma (in log domain). + Default is 0, where no gamma augmentation occurs. + :param contrast_inversion: whether to perform contrast inversion (i.e. 1 - x). If True, this is performed randomly + for each element of the batch, as well as for each channel. + :param separate_channels: whether to augment all channels separately. Default is True. + :param prob_noise: probability to apply noise injection + :param prob_gamma: probability to apply gamma augmentation + """ + + def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False, + separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs): + + # shape attributes + self.n_dims = None + self.n_channels = None + self.flatten_shape = None + self.expand_minmax_dim = None + self.one = None + + # inputs + self.noise_std = noise_std + self.clip = clip + self.clip_values = None + self.normalise = normalise + self.norm_perc = norm_perc + self.perc = None + self.gamma_std = gamma_std + self.separate_channels = separate_channels + self.contrast_inversion = contrast_inversion + self.prob_noise = prob_noise + self.prob_gamma = prob_gamma + + super(IntensityAugmentation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["noise_std"] = self.noise_std + config["clip"] = self.clip + config["normalise"] = self.normalise + config["norm_perc"] = self.norm_perc + config["gamma_std"] = self.gamma_std + config["separate_channels"] = self.separate_channels + config["prob_noise"] = self.prob_noise + config["prob_gamma"] = self.prob_gamma + return config + + def build(self, input_shape): + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + self.flatten_shape = np.prod(np.array(input_shape[1:-1])) + self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape + self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1 + self.one = tf.ones([1], dtype='int32') + if self.clip: + self.clip_values = utils.reformat_to_list(self.clip) + self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]] + else: + self.clip_values = None + if self.norm_perc: + self.perc = utils.reformat_to_list(self.norm_perc) + self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + else: + self.perc = None + + self.built = True + super(IntensityAugmentation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: + sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0) + if self.separate_channels: + sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) + else: + sample_shape = tf.concat([sample_shape, self.one], 0) + else: + sample_shape = None + + # add noise with predefined probability + if self.noise_std > 0: + noise_stddev = tf.random.uniform(sample_shape, maxval=self.noise_std) + if self.separate_channels: + noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) + else: + noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev) + noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels])) + if self.prob_noise == 1: + inputs = inputs + noise + else: + inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + inputs + noise, inputs) + + # clip images to given values + if self.clip_values is not None: + inputs = K.clip(inputs, self.clip_values[0], self.clip_values[1]) + + # normalise + if self.normalise: + # define robust min and max by sorting values and taking percentile + if self.perc is not None: + if self.separate_channels: + shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0) + else: + shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) + intensities = tf.sort(tf.reshape(inputs, shape), axis=1) + m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] + M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...] + # simple min and max + else: + m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + M = K.max(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + # normalise + m = l2i_et.expand_dims(m, axis=[1] * self.expand_minmax_dim) + M = l2i_et.expand_dims(M, axis=[1] * self.expand_minmax_dim) + inputs = tf.clip_by_value(inputs, m, M) + inputs = (inputs - m) / (M - m + K.epsilon()) + + # apply voxel-wise exponentiation with predefined probability + if self.gamma_std > 0: + gamma = tf.random.normal(sample_shape, stddev=self.gamma_std) + if self.prob_gamma == 1: + inputs = tf.math.pow(inputs, tf.math.exp(gamma)) + else: + inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), + tf.math.pow(inputs, tf.math.exp(gamma)), inputs) + + # apply random contrast inversion + if self.contrast_inversion: + rand_invert = tf.less(tf.random.uniform(sample_shape, maxval=1), 0.5) + split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) + split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) + inverted_channel = list() + for (channel, invert) in zip(split_channels, split_rand_invert): + inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype)) + inputs = tf.concat(inverted_channel, -1) + + return inputs + + @staticmethod + def _single_invert(inputs): + return K.switch(tf.squeeze(inputs[1]), 1 - inputs[0], inputs[0]) + + +class DiceLoss(Layer): + """This layer computes the soft Dice loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: dice_loss = DiceLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(DiceLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(DiceLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + + # compute dice loss for each label + top = 2 * gt * pred + bottom = tf.math.square(gt) + tf.math.square(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) + boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + top *= boundary_weights_tensor + bottom *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # compute loss + top = tf.math.reduce_sum(top, self.spatial_axes) + bottom = tf.math.reduce_sum(bottom, self.spatial_axes) + dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon()) + loss = 1 - dice + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self. class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def compute_output_shape(self, input_shape): + return [[]] + + +class WeightedL2Loss(Layer): + """This layer computes a L2 loss weighted by a specified factor (target_value) between two tensors. + This is designed to be used on the layer before the softmax. + The tensors are expected to have the same shape [batchsize, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: wl2_loss = WeightedL2Loss()([gt, pred]) + + :param target_value: target value for the layer before softmax: target_value when gt = 1, -target_value when gt = 0. + """ + + def __init__(self, target_value=5, **kwargs): + self.target_value = target_value + self.n_labels = None + super(WeightedL2Loss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["target_value"] = self.target_value + return config + + def build(self, input_shape): + assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + self.n_labels = input_shape[0][-1] + self.built = True + super(WeightedL2Loss, self).build(input_shape) + + def call(self, inputs, **kwargs): + gt = inputs[0] + pred = inputs[1] + weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1) + return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels) + + def compute_output_shape(self, input_shape): + return [[]] + + +class CrossEntropyLoss(Layer): + """This layer computes the cross-entropy loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: ce_loss = CrossEntropyLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(CrossEntropyLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'CrossEntropy expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims)) + + self.built = True + super(CrossEntropyLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()) # to avoid log(0) + + # compare prediction/target, ce has the same shape has the input tensors + ce = -gt * tf.math.log(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) + boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + ce *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # apply class weighting across labels. By the end of this, ce still has the same shape has the input tensors. + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + ce = tf.reduce_sum(ce * self.class_weights_tens, -1) + + # sum along label axis, and take the mean along spatial dimensions + ce = tf.math.reduce_mean(tf.math.reduce_sum(ce, axis=-1)) + + return ce + + def compute_output_shape(self, input_shape): + return [[]] + + +class MomentLoss(Layer): + """This layer computes a moment loss between two tensors. Specifically, it computes the distance between the centres + of gravity for all the channels of the two tensors, and then returns a value averaged across all channels. + These tensors are expected to have the same shape [batch, size_dim1, ..., size_dimN, n_channels]. + The first input tensor is the GT and the second is the prediction: moment_loss = MomentLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, class_weights=None, enable_checks=False, **kwargs): + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.enable_checks = enable_checks + self.spatial_axes = None + self.coordinates = None + super(MomentLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + + # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan) + self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1) + self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32') + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(MomentLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] # (B, dim1, dim2, ..., dimN, nchan) + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + + # compute loss + gt_mean_coordinates = self._mean_coordinates(gt) # (B, ndim, nchan) + pred_mean_coordinates = self._mean_coordinates(pred) + loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1)) # (B, nchan) + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def _mean_coordinates(self, tensor): + tensor = l2i_et.expand_dims(tensor, axis=-2) # (B, dim1, dim2, ..., dimN, 1, nchan) + numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes) # (B, ndim, nchan) + denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + return numerator / denominator + + def compute_output_shape(self, input_shape): + return [[]] + + +class ResetValuesToZero(Layer): + """This layer enables to reset given values to 0 within the input tensors. + + :param values: list of values to be reset to 0. + + example: + input = tf.convert_to_tensor(np.array([[1, 0, 2, 2, 2, 2, 0], + [1, 3, 3, 3, 3, 3, 3], + [1, 0, 0, 0, 4, 4, 4]])) + values = [1, 3] + ResetValuesToZero(values)(input) + >> [[0, 0, 2, 2, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 4, 4]] + """ + + def __init__(self, values, **kwargs): + assert values is not None, 'please provide correct list of values, received None' + self.values = utils.reformat_to_list(values) + self.values_tens = None + self.n_values = len(values) + super(ResetValuesToZero, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["values"] = self.values + return config + + def build(self, input_shape): + self.values_tens = tf.convert_to_tensor(self.values) + self.built = True + super(ResetValuesToZero, self).build(input_shape) + + def call(self, inputs, **kwargs): + values = tf.cast(self.values_tens, dtype=inputs.dtype) + for i in range(self.n_values): + inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs) + return inputs + + +class ConvertLabels(Layer): + """Convert all labels in a tensor by the corresponding given set of values. + labels_converted = ConvertLabels(source_values, dest_values)(labels). + labels must be an int32 tensor, and labels_converted will also be int32. + + :param source_values: list of all the possible values in labels. Must be a list or a 1D numpy array. + :param dest_values: list of all the target label values. Must be ordered the same as source values: + labels[labels == source_values[i]] = dest_values[i]. + If None (default), dest_values is equal to [0, ..., N-1], where N is the total number of values in source_values, + which enables to remap label maps to [0, ..., N-1]. + """ + + def __init__(self, source_values, dest_values=None, **kwargs): + self.source_values = source_values + self.dest_values = dest_values + self.lut = None + super(ConvertLabels, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["source_values"] = self.source_values + config["dest_values"] = self.dest_values + return config + + def build(self, input_shape): + self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32') + self.built = True + super(ConvertLabels, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.gather(self.lut, tf.cast(inputs, dtype='int32')) + + +class PadAroundCentre(Layer): + """Pad the input tensor to the specified shape with the given value. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param pad_margin: margin to use for padding. The tensor will be padded by the provided margin on each side. + Can either be a number (all axes padded with the same margin), or a list/numpy array of length n_dims. + example: if tensor is of shape [batch, x, y, z, n_channels] and margin=10, then the padded tensor will be of + shape [batch, x+2*10, y+2*10, z+2*10, n_channels]. + :param pad_shape: shape to pad the tensor to. Can either be a number (all axes padded to the same shape), or a + list/numpy array of length n_dims. + :param value: value to pad the tensors with. Default is 0. + """ + + def __init__(self, pad_margin=None, pad_shape=None, value=0, **kwargs): + self.pad_margin = pad_margin + self.pad_shape = pad_shape + self.value = value + self.pad_margin_tens = None + self.pad_shape_tens = None + self.n_dims = None + super(PadAroundCentre, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["pad_margin"] = self.pad_margin + config["pad_shape"] = self.pad_shape + config["value"] = self.value + return config + + def build(self, input_shape): + # input shape + self.n_dims = len(input_shape) - 2 + shape = list(input_shape) + shape[0] = 0 + shape[-1] = 0 + + if self.pad_margin is not None: + assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.' + + # reformat padding margins + pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2)) + self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32') + + elif self.pad_shape is not None: + assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.' + + # pad shape + tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32') + self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0]) + self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32') + self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) + + # padding margin + min_margins = (self.pad_shape_tens - tensor_shape) / 2 + max_margins = self.pad_shape_tens - tensor_shape - min_margins + self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) + + else: + raise Exception('please either provide a padding shape or a padding margin.') + + self.built = True + super(PadAroundCentre, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value) + + +class MaskEdges(Layer): + """Reset the edges of a tensor to zero (i.e. with bands of zeros along the specified axes). + The width of the zero-band is randomly drawn from a uniform distribution, whose range is given in boundaries. + + :param axes: axes along which to reset edges to zero. Can be an int (single axis), or a sequence. + :param boundaries: numpy array of shape (len(axes), 4). Each row contains the two bounds of the uniform + distributions from which we draw the width of the zero-bands on each side. + Those bounds must be expressed in relative side (i.e. between 0 and 1). + :return: a tensor of the same shape as the input, with bands of zeros along the specified axes. + + example: + tensor=tf.constant([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]) # shape = [1,10,10,1] + axes=1 + boundaries = np.array([[0.2, 0.45, 0.85, 0.9]]) + + In this case, we reset the edges along the 2nd dimension (i.e. the 1st dimension after the batch dimension), + the 1st zero-band will expand from the 1st row to a number drawn from [0.2*tensor.shape[1], 0.45*tensor.shape[1]], + and the 2nd zero-band will expand from a row drawn from [0.85*tensor.shape[1], 0.9*tensor.shape[1]], to the end of + the tensor. A possible output could be: + array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]) # shape = [1,10,10,1] + """ + + def __init__(self, axes, boundaries, prob_mask=1, **kwargs): + self.axes = utils.reformat_to_list(axes, dtype='int') + self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes)) + self.prob_mask = prob_mask + self.inputshape = None + super(MaskEdges, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axes"] = self.axes + config["boundaries"] = self.boundaries + config["prob_mask"] = self.prob_mask + return config + + def build(self, input_shape): + self.inputshape = input_shape + self.built = True + super(MaskEdges, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # build mask + mask = tf.ones_like(inputs) + for i, axis in enumerate(self.axes): + + # select restricting indices + axis_boundaries = self.boundaries[i, :] + idx1 = tf.math.round(tf.random.uniform([1], + minval=axis_boundaries[0] * self.inputshape[axis], + maxval=axis_boundaries[1] * self.inputshape[axis])) + idx2 = tf.math.round(tf.random.uniform([1], + minval=axis_boundaries[2] * self.inputshape[axis], + maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1) + idx3 = self.inputshape[axis] - idx1 - idx2 + split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32') + + # update mask + split_list = tf.split(inputs, split_idx, axis=axis) + tmp_mask = tf.concat([tf.zeros_like(split_list[0]), + tf.ones_like(split_list[1]), + tf.zeros_like(split_list[2])], axis=axis) + mask = mask * tmp_mask + + # mask second_channel + tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + inputs * mask, + inputs) + + return [tensor, mask] + + def compute_output_shape(self, input_shape): + return [input_shape] * 2 + + +class ImageGradients(Layer): + + def __init__(self, gradient_type='sobel', return_magnitude=False, **kwargs): + + self.gradient_type = gradient_type + assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \ + 'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type + + # shape + self.n_dims = 0 + self.shape = None + self.n_channels = 0 + + # convolution params if sobel diff + self.stride = None + self.kernels = None + self.convnd = None + + self.return_magnitude = return_magnitude + + super(ImageGradients, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["gradient_type"] = self.gradient_type + config["return_magnitude"] = self.return_magnitude + return config + + def build(self, input_shape): + + # get shapes + self.n_dims = len(input_shape) - 2 + self.shape = input_shape[1:] + self.n_channels = input_shape[-1] + + # prepare kernel if sobel gradients + if self.gradient_type == 'sobel': + self.kernels = l2i_et.sobel_kernels(self.n_dims) + self.stride = [1] * (self.n_dims + 2) + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + else: + self.kernels = self.convnd = self.stride = None + + self.built = True + super(ImageGradients, self).build(input_shape) + + def call(self, inputs, **kwargs): + + image = inputs + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + gradients = list() + + # sobel method + if self.gradient_type == 'sobel': + # get sobel gradients in each direction + for n in range(self.n_dims): + gradient = image + # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel + for k in self.kernels[n]: + gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + gradients.append(gradient) + + # 1-step method, only supports 2 and 3D + else: + + # get 1-step diff + if self.n_dims == 2: + gradients.append(image[:, 1:, :, :] - image[:, :-1, :, :]) # dx + gradients.append(image[:, :, 1:, :] - image[:, :, :-1, :]) # dy + + elif self.n_dims == 3: + gradients.append(image[:, 1:, :, :, :] - image[:, :-1, :, :, :]) # dx + gradients.append(image[:, :, 1:, :, :] - image[:, :, :-1, :, :]) # dy + gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz + + else: + raise Exception('ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims) + + # pad with zeros to return tensors of the same shape as input + for i in range(self.n_dims): + tmp_shape = list(self.shape) + tmp_shape[i] = 1 + zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype) + gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) + + # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis + if self.return_magnitude: + gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1)) + else: + gradients = tf.concat(gradients, axis=-1) + + return gradients + + def compute_output_shape(self, input_shape): + if not self.return_magnitude: + input_shape = list(input_shape) + input_shape[-1] = self.n_dims + return tuple(input_shape) + + +class RandomDilationErosion(Layer): + """ + GPU implementation of binary dilation or erosion. The operation can be chosen to be always a dilation, or always an + erosion, or randomly choosing between them for each element of the batch. + The chosen operation is applied to the input with a given probability. Moreover, it is also possible to randomise + the factor of the operation for each element of the mini-batch. + :param min_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + :param max_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + Set it to the same value as min_factor to always perform dilation/erosion with the same factor. + :param prob: probability with which to apply the selected operation to the input. + :param operation: which operation to apply. Can be 'dilation' or 'erosion' or 'random'. + :param return_mask: if operation is erosion and the input of this layer is a label map, we have the + choice to either return the eroded label map or the mask (return_mask=True) + """ + + def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False, + **kwargs): + + self.min_factor = min_factor + self.max_factor = max_factor + self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor + self.prob = prob + self.operation = operation + self.return_mask = return_mask + self.n_dims = None + self.inshape = None + self.n_channels = None + self.convnd = None + super(RandomDilationErosion, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_factor"] = self.min_factor + config["max_factor"] = self.max_factor + config["max_factor_dilate"] = self.max_factor_dilate + config["prob"] = self.prob + config["operation"] = self.operation + config["return_mask"] = self.return_mask + return config + + def build(self, input_shape): + + # input shape + self.inshape = input_shape + self.n_dims = len(self.inshape) - 2 + self.n_channels = self.inshape[-1] + + # prepare convolution + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + + self.built = True + super(RandomDilationErosion, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sample probability of applying operation. If random negative is erosion and positive is dilation + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0) + if self.operation == 'dilation': + prob = tf.random.uniform(shape, 0, 1) + elif self.operation == 'erosion': + prob = tf.random.uniform(shape, -1, 0) + elif self.operation == 'random': + prob = tf.random.uniform(shape, -1, 1) + else: + raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation) + + # build kernel + if self.min_factor == self.max_factor: + dist_threshold = self.min_factor * tf.ones(shape, dtype='int32') + else: + if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'): + dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32') + else: + dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32') + kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor) + + # convolve input mask with kernel according to given probability + mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32') + mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) + mask = tf.cast(mask, 'bool') + + if self.return_mask: + return mask + else: + return inputs * tf.cast(mask, dtype=inputs.dtype) + + def _sample_factor(self, inputs): + return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0), + tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'), + tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')), + dtype='float32') + + def _single_blur(self, inputs): + # dilate... + new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), + tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], + [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), + inputs[0]) + # ...or erode + new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)), + 1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1], + [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), + new_mask) + return new_mask + + def compute_output_shape(self, input_shape): + return input_shape