SynthSeg/model_inputs.py
|
"""
If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""


# python imports
import numpy as np
import numpy.random as npr

# third-party imports
from ext.lab2im import utils


def build_model_inputs(path_label_maps,
                       n_labels,
                       batchsize=1,
                       n_channels=1,
                       subjects_prob=None,
                       generation_classes=None,
                       prior_distributions='uniform',
                       prior_means=None,
                       prior_stds=None,
                       use_specific_stats_for_channel=False,
                       mix_prior_and_random=False):
    """
    This function builds a generator to be used as input to the label_to_image model: it yields the input label maps,
    as well as the means and stds defining the parameters of the GMM (which change at each mini-batch).
    :param path_label_maps: list of the paths of the input label maps.
    :param n_labels: number of labels in the input label maps.
    :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
    :param n_channels: (optional) number of channels to be synthesised. Default is 1.
    :param subjects_prob: (optional) relative weights (they do not need to sum to 1) with which to pick the provided
    label maps at each mini-batch. Must be a 1D numpy array of the same length as path_label_maps.
    :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
    distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a
    1d numpy array. It should have the same length as the number of labels (n_labels), and contain values between 0
    and K-1, where K is the total number of classes. Default is all labels having different classes.
    :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
    Can either be 'uniform', or 'normal'. Default is 'uniform'.
    :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
    these prior distributions are uniform or normal, they each require 2 hyperparameters. Thus prior_means can be:
    1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
    uniform, [mean, std] if the distribution is normal. The GMM means are then independently sampled at each
    mini-batch from the same distribution.
    2) an array of shape (2, K), where K is the number of classes (K=n_labels if generation_classes is not given).
    The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch
    from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from
    N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
    3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
    from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
    modality from the n_mod possibilities, and we sample the GMM means like in 2).
    If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
    (n_mod=n_channels), thus we select the block corresponding to each channel rather than drawing it randomly.
    4) the path to such a numpy array.
    Default is None, which corresponds to prior_means = [25, 225]. A sketch of option 2) is given in a comment
    right after this docstring.
    :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
    Default is None, which corresponds to prior_stds = [5, 25].
    :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must
    only be used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
    :param mix_prior_and_random: (optional) if prior_means is not None, resets the priors to their default values
    in half of the cases, thus generating images of random contrast. Default is False.
    """
|
|
    # allocate a unique class to each label if generation_classes is not given
    if generation_classes is None:
        generation_classes = np.arange(n_labels)
    n_classes = len(np.unique(generation_classes))

    # make sure subjects_prob sums to 1
    subjects_prob = utils.load_array_if_path(subjects_prob)
    if subjects_prob is not None:
        subjects_prob /= np.sum(subjects_prob)

    # Generate!
    while True:

        # randomly pick as many images as batchsize
        indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob)

        # initialise input lists
        list_label_maps = []
        list_means = []
        list_stds = []

        for idx in indices:

            # load input label map
            lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
            # for 'seg_cerebral' label maps, reset label 24 (CSF in the FreeSurfer convention) to background
            # in 30% of cases
            if (npr.uniform() > 0.7) and ('seg_cerebral' in path_label_maps[idx]):
                lab[lab == 24] = 0

            # add label map to inputs
            list_label_maps.append(utils.add_axis(lab, axis=[0, -1]))

            # add means and standard deviations to inputs; these are filled one channel at a time along the last
            # axis, ending up with shape (1, n_labels, n_channels)
            means = np.empty((1, n_labels, 0))
            stds = np.empty((1, n_labels, 0))
            for channel in range(n_channels):

                # retrieve channel-specific stats if necessary
                if isinstance(prior_means, np.ndarray):
                    if (prior_means.shape[0] > 2) and use_specific_stats_for_channel:
                        if prior_means.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_means does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_means = prior_means
                else:
                    tmp_prior_means = prior_means
                # with mix_prior_and_random, fall back to the default priors (random contrast) in half of the cases
                if (prior_means is not None) and mix_prior_and_random and (npr.uniform() > 0.5):
                    tmp_prior_means = None
                if isinstance(prior_stds, np.ndarray):
                    if (prior_stds.shape[0] > 2) and use_specific_stats_for_channel:
                        if prior_stds.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_stds = prior_stds
                else:
                    tmp_prior_stds = prior_stds
                if (prior_stds is not None) and mix_prior_and_random and (npr.uniform() > 0.5):
                    tmp_prior_stds = None
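                # For instance (hypothetical shapes): with n_channels = 2 and prior arrays of shape (4, K),
                # channel 0 uses rows 0-1 and channel 1 uses rows 2-3 when use_specific_stats_for_channel is True.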

                # draw means and std devs from priors
                tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions,
                                                                       125., 125., positive_only=True)
                tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions,
                                                                      15., 15., positive_only=True)
                random_coef = npr.uniform()
                if random_coef > 0.95:  # reset the background to 0 in 5% of cases
                    tmp_classes_means[0] = 0
                    tmp_classes_stds[0] = 0
                elif random_coef > 0.7:  # draw the background from a low-intensity Gaussian in 25% of cases
                    tmp_classes_means[0] = npr.uniform(0, 15)
                    tmp_classes_stds[0] = npr.uniform(0, 5)
                # map class statistics back to individual labels, and stack them along the channel axis
                tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1])
                tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1])
                means = np.concatenate([means, tmp_means], axis=-1)
                stds = np.concatenate([stds, tmp_stds], axis=-1)
            list_means.append(means)
            list_stds.append(stds)

        # build list of inputs for generation model
        list_inputs = [list_label_maps, list_means, list_stds]
        if batchsize > 1:  # concatenate the arrays of each input type along the batch axis
            list_inputs = [np.concatenate(item, 0) for item in list_inputs]
        else:
            list_inputs = [item[0] for item in list_inputs]

        yield list_inputs
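

# Minimal usage sketch (hypothetical paths and label counts; in practice path_label_maps would typically come
# from a folder listing, e.g. utils.list_images_in_folder). It draws one mini-batch from the generator above.
if __name__ == '__main__':

    # two hypothetical label maps containing 4 labels, grouped into 3 intensity classes
    example_paths = ['/data/labels/subject_01.nii.gz', '/data/labels/subject_02.nii.gz']
    example_classes = np.array([0, 1, 1, 2])  # the second and third labels share the same Gaussian

    # build the generator and draw one mini-batch of inputs for the label_to_image model
    inputs_generator = build_model_inputs(example_paths, n_labels=4, batchsize=2,
                                          generation_classes=example_classes)
    label_maps, means, stds = next(inputs_generator)
    print(label_maps.shape, means.shape, stds.shape)  # e.g. (2, *volume_shape, 1), (2, 4, 1), (2, 4, 1)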