--- a +++ b/ext/lab2im/image_generator.py @@ -0,0 +1,266 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import numpy.random as npr + +# project imports +from ext.lab2im import utils +from ext.lab2im import edit_volumes +from ext.lab2im.lab2im_model import lab2im_model + + +class ImageGenerator: + + def __init__(self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions='uniform', + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15): + """ + This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels + maps, and a python generator that supplies the input data for this model. + To generate pairs of image/labels you can just call the method generate_image() on an object of this class. + + :param labels_dir: path of folder with all input label maps, or to a single label map. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + # label maps-related parameters + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in + the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps + will be converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + + # output-related parameters + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthetised. Default is 1. + :param target_res: (optional) target resolution of the generated images and corresponding label maps. + If None, the outputs will have the same resolution as the input label maps. + Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites + output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or + the path to a 1d numpy array. + + # GMM-sampling parameters + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a + 1d numpy array, or the path to a 1d numpy array. + It should have the same length as generation_labels, and contain values between 0 and K-1, where K is the total + number of classes. Default is all labels have different classes (K=len(generation_labels)). + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each + mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + + # blurring parameters + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is + given or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a c + coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range]. + If None, no randomisation. Default is 1.15. + """ + + # prepare data files + self.labels_paths = utils.list_images_in_folder(labels_dir) + + # generation parameters + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + self.n_channels = n_channels + if generation_labels is not None: + self.generation_labels = utils.load_array_if_path(generation_labels) + else: + self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) + if output_labels is not None: + self.output_labels = utils.load_array_if_path(output_labels) + else: + self.output_labels = self.generation_labels + self.target_res = utils.load_array_if_path(target_res) + self.batchsize = batchsize + # preliminary operations + self.output_shape = utils.load_array_if_path(output_shape) + self.output_div_by_n = output_div_by_n + # GMM parameters + self.prior_distributions = prior_distributions + if generation_classes is not None: + self.generation_classes = utils.load_array_if_path(generation_classes) + assert self.generation_classes.shape == self.generation_labels.shape, \ + 'if provided, generation labels should have the same shape as generation_labels' + unique_classes = np.unique(self.generation_classes) + assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ + 'generation_classes should a linear range between 0 and its maximum value.' + else: + self.generation_classes = np.arange(self.generation_labels.shape[0]) + self.prior_means = utils.load_array_if_path(prior_means) + self.prior_stds = utils.load_array_if_path(prior_stds) + self.use_specific_stats_for_channel = use_specific_stats_for_channel + + # blurring parameters + self.blur_range = blur_range + + # build transformation model + self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model() + + # build generator for model inputs + self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels)) + + # build brain generator + self.image_generator = self._build_image_generator() + + def _build_lab2im_model(self): + # build_model + lab_to_im_model = lab2im_model(labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + blur_range=self.blur_range) + out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] + return lab_to_im_model, out_shape + + def _build_image_generator(self): + while True: + model_inputs = next(self.model_inputs_generator) + [image, labels] = self.labels_to_image_model.predict(model_inputs) + yield image, labels + + def generate_image(self): + """call this method when an object of this class has been instantiated to generate new brains""" + (image, labels) = next(self.image_generator) + # put back images in native space + list_images = list() + list_labels = list() + for i in range(self.batchsize): + list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff, + n_dims=self.n_dims)) + list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff, + n_dims=self.n_dims)) + image = np.stack(list_images, axis=0) + labels = np.stack(list_labels, axis=0) + return np.squeeze(image), np.squeeze(labels) + + def _build_model_inputs(self, n_labels): + + # get label info + _, _, n_dims, _, _, _ = utils.get_volume_info(self.labels_paths[0]) + + # Generate! + while True: + + # randomly pick as many images as batchsize + indices = npr.randint(len(self.labels_paths), size=self.batchsize) + + # initialise input lists + list_label_maps = [] + list_means = [] + list_stds = [] + + for idx in indices: + + # load label in identity space, and add them to inputs + y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4)) + list_label_maps.append(utils.add_axis(y, axis=[0, -1])) + + # add means and standard deviations to inputs + means = np.empty((1, n_labels, 0)) + stds = np.empty((1, n_labels, 0)) + for channel in range(self.n_channels): + + # retrieve channel specific stats if necessary + if isinstance(self.prior_means, np.ndarray): + if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel: + if self.prior_means.shape[0] / 2 != self.n_channels: + raise ValueError("the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :] + else: + tmp_prior_means = self.prior_means + else: + tmp_prior_means = self.prior_means + if isinstance(self.prior_stds, np.ndarray): + if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel: + if self.prior_stds.shape[0] / 2 != self.n_channels: + raise ValueError("the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :] + else: + tmp_prior_stds = self.prior_stds + else: + tmp_prior_stds = self.prior_stds + + # draw means and std devs from priors + tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, + self.prior_distributions, 125., 100., + positive_only=True) + tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, + self.prior_distributions, 15., 10., + positive_only=True) + tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1]) + tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1]) + means = np.concatenate([means, tmp_means], axis=-1) + stds = np.concatenate([stds, tmp_stds], axis=-1) + list_means.append(means) + list_stds.append(stds) + + # build list of inputs of augmentation model + list_inputs = [list_label_maps, list_means, list_stds] + if self.batchsize > 1: # concatenate individual input types if batchsize > 1 + list_inputs = [np.concatenate(item, 0) for item in list_inputs] + else: + list_inputs = [item[0] for item in list_inputs] + + yield list_inputs