Diff of /SynthSeg/model_inputs.py [000000] .. [e571d1]

Switch to unified view

a b/SynthSeg/model_inputs.py
1
"""
2
If you use this code, please cite one of the SynthSeg papers:
3
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4
5
Copyright 2020 Benjamin Billot
6
7
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8
compliance with the License. You may obtain a copy of the License at
9
https://www.apache.org/licenses/LICENSE-2.0
10
Unless required by applicable law or agreed to in writing, software distributed under the License is
11
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12
implied. See the License for the specific language governing permissions and limitations under the
13
License.
14
"""
15
16
17
# python imports
18
import numpy as np
19
import numpy.random as npr
20
21
# third-party imports
22
from ext.lab2im import utils
23
24
25
def build_model_inputs(path_label_maps,
26
                       n_labels,
27
                       batchsize=1,
28
                       n_channels=1,
29
                       subjects_prob=None,
30
                       generation_classes=None,
31
                       prior_distributions='uniform',
32
                       prior_means=None,
33
                       prior_stds=None,
34
                       use_specific_stats_for_channel=False,
35
                       mix_prior_and_random=False):
36
    """
37
    This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the
38
    input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch).
39
    :param path_label_maps: list of the paths of the input label maps.
40
    :param n_labels: number of labels in the input label maps.
41
    :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1.
42
    :param n_channels: (optional) number of channels to be synthesised. Default is 1.
43
    :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick
44
    the provided label maps at each minibatch. Must be a 1D numpy array, as long as path_label_maps.
45
    :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
46
    distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a
47
    1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K
48
    is the total number of classes. Default is all labels have different classes.
49
    :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
50
    Can either be 'uniform', or 'normal'. Default is 'uniform'.
51
    :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
52
    these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be:
53
    1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
54
    uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each
55
    mini_batch from the same distribution.
56
    2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
57
    not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch
58
    from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from
59
    N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
60
    3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
61
    from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
62
    modality from the n_mod possibilities, and we sample the GMM means like in 2).
63
    If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel
64
    (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it.
65
    4) the path to such a numpy array.
66
    Default is None, which corresponds to prior_means = [25, 225].
67
    :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
68
    Default is None, which corresponds to prior_stds = [5, 25].
69
    :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be
70
    only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
71
    :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default
72
    values for half of these cases, and thus generate images of random contrast.
73
    """
74
75
    # allocate unique class to each label if generation classes is not given
76
    if generation_classes is None:
77
        generation_classes = np.arange(n_labels)
78
    n_classes = len(np.unique(generation_classes))
79
80
    # make sure subjects_prob sums to 1
81
    subjects_prob = utils.load_array_if_path(subjects_prob)
82
    if subjects_prob is not None:
83
        subjects_prob /= np.sum(subjects_prob)
84
85
    # Generate!
86
    while True:
87
88
        # randomly pick as many images as batchsize
89
        indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob)
90
91
        # initialise input lists
92
        list_label_maps = []
93
        list_means = []
94
        list_stds = []
95
96
        for idx in indices:
97
98
            # load input label map
99
            lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
100
            if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]):
101
                lab[lab == 24] = 0
102
103
            # add label map to inputs
104
            list_label_maps.append(utils.add_axis(lab, axis=[0, -1]))
105
106
            # add means and standard deviations to inputs
107
            means = np.empty((1, n_labels, 0))
108
            stds = np.empty((1, n_labels, 0))
109
            for channel in range(n_channels):
110
111
                # retrieve channel specific stats if necessary
112
                if isinstance(prior_means, np.ndarray):
113
                    if (prior_means.shape[0] > 2) & use_specific_stats_for_channel:
114
                        if prior_means.shape[0] / 2 != n_channels:
115
                            raise ValueError("the number of blocks in prior_means does not match n_channels. This "
116
                                             "message is printed because use_specific_stats_for_channel is True.")
117
                        tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
118
                    else:
119
                        tmp_prior_means = prior_means
120
                else:
121
                    tmp_prior_means = prior_means
122
                if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
123
                    tmp_prior_means = None
124
                if isinstance(prior_stds, np.ndarray):
125
                    if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel:
126
                        if prior_stds.shape[0] / 2 != n_channels:
127
                            raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
128
                                             "message is printed because use_specific_stats_for_channel is True.")
129
                        tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
130
                    else:
131
                        tmp_prior_stds = prior_stds
132
                else:
133
                    tmp_prior_stds = prior_stds
134
                if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
135
                    tmp_prior_stds = None
136
137
                # draw means and std devs from priors
138
                tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions,
139
                                                                       125., 125., positive_only=True)
140
                tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions,
141
                                                                      15., 15., positive_only=True)
142
                random_coef = npr.uniform()
143
                if random_coef > 0.95:  # reset the background to 0 in 5% of cases
144
                    tmp_classes_means[0] = 0
145
                    tmp_classes_stds[0] = 0
146
                elif random_coef > 0.7:  # reset the background to low Gaussian in 25% of cases
147
                    tmp_classes_means[0] = npr.uniform(0, 15)
148
                    tmp_classes_stds[0] = npr.uniform(0, 5)
149
                tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1])
150
                tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1])
151
                means = np.concatenate([means, tmp_means], axis=-1)
152
                stds = np.concatenate([stds, tmp_stds], axis=-1)
153
            list_means.append(means)
154
            list_stds.append(stds)
155
156
        # build list of inputs for generation model
157
        list_inputs = [list_label_maps, list_means, list_stds]
158
        if batchsize > 1:  # concatenate each input type if batchsize > 1
159
            list_inputs = [np.concatenate(item, 0) for item in list_inputs]
160
        else:
161
            list_inputs = [item[0] for item in list_inputs]
162
163
        yield list_inputs