--- a +++ b/ext/lab2im/lab2im_model.py @@ -0,0 +1,174 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import keras.layers as KL +from keras.models import Model + +# project imports +from ext.lab2im import utils +from ext.lab2im import layers +from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling + + +def lab2im_model(labels_shape, + n_channels, + generation_labels, + output_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + blur_range=1.15): + """ + This function builds a keras/tensorflow model to generate images from provided label maps. + The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. + The model will take as inputs: + -a label map + -a vector containing the means of the Gaussian Mixture Model for each label, + -a vector containing the standard deviations of the Gaussian Mixture Model for each label, + -an array of size batch*(n_dims+1)*(n_dims+1) representing a linear transformation + The model returns: + -the generated image normalised between 0 and 1. + -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). + :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. + :param n_channels: number of channels to be synthetised. + :param generation_labels: list of all possible label values in the input label maps. + Can be a sequence or a 1d numpy array. + :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps + returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to + output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param atlas_res: resolution of the input label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param target_res: target resolution of the generated images and corresponding label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param output_shape: (optional) desired shape of the output images. + If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two + resolutions are different, the output will be resized with trilinear interpolation to output_shape. + Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape + if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given + or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled + from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15. + """ + + # reformat resolutions + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] + target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + + # get shapes + crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + + # define model inputs + labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32') + means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') + stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input') + + # deform labels + labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input) + + # cropping + if crop_shape != labels_shape: + labels._keras_shape = tuple(labels.get_shape().as_list()) + labels = layers.RandomCrop(crop_shape)(labels) + + # build synthetic image + labels._keras_shape = tuple(labels.get_shape().as_list()) + image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + + # apply bias field + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image) + + # intensity augmentation + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image) + + # blur image + sigma = blurring_sigma_for_downsampling(atlas_res, target_res) + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image) + + # resample to target res + if crop_shape != output_shape: + image = resample_tensor(image, output_shape, interp_method='linear') + labels = resample_tensor(labels, output_shape, interp_method='nearest') + + # reset unwanted labels to zero + labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + + # build model (dummy layer enables to keep the labels when plugging this model to other models) + image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) + brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels]) + + return brain_model + + +def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): + + n_dims = len(atlas_res) + + # get resampling factor + if atlas_res.tolist() != target_res.tolist(): + resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] + else: + resample_factor = None + + # output shape specified, need to get cropping shape, and resample shape if necessary + if output_shape is not None: + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + + # make sure that output shape is smaller or equal to label shape + if resample_factor is not None: + output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + else: + output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape] + if output_shape != tmp_shape: + print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, + tmp_shape)) + output_shape = tmp_shape + + # get cropping and resample shape + if resample_factor is not None: + cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + else: + cropping_shape = output_shape + + # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n + else: + cropping_shape = labels_shape + if resample_factor is not None: + output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)] + else: + output_shape = cropping_shape + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer') + for s in output_shape] + + return cropping_shape, output_shape