|
# ext/lab2im/edit_tensors.py

"""
This file contains functions to handle keras/tensorflow tensors.
    - blurring_sigma_for_downsampling
    - gaussian_kernel
    - sobel_kernels
    - unit_kernel
    - resample_tensor
    - expand_dims

If you use this code, please cite the first SynthSeg paper:
https://github.com/BBillot/lab2im/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""
|
|
# python imports
import numpy as np
import tensorflow as tf
import keras.layers as KL
import keras.backend as K
from itertools import combinations

# project imports
from ext.lab2im import utils

# third-party imports
import ext.neuron.layers as nrn_layers
from ext.neuron.utils import volshape_to_meshgrid
|
|
def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None):
    """Compute standard deviations of 1d gaussian masks for image blurring before downsampling.
    :param current_res: resolution of the volume before downsampling.
    Can be a 1d numpy array or list, or a tensor of the same length as downsample_res.
    :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor.
    :param mult_coef: (optional) multiplicative coefficient for the blurring kernel. Default is 0.75.
    :param thickness: (optional) slice thickness in each dimension. Must be the same type as downsample_res.
    :return: standard deviations of the blurring masks, given as the same type as downsample_res (array or tensor).
    """

    if not tf.is_tensor(downsample_res):

        # get blurring resolution (min between downsample_res and thickness)
        current_res = np.array(current_res)
        downsample_res = np.array(downsample_res)
        if thickness is not None:
            downsample_res = np.minimum(downsample_res, np.array(thickness))

        # get std deviation for blurring kernels
        if mult_coef is None:
            sigma = 0.75 * downsample_res / current_res
            sigma[downsample_res == current_res] = 0.5
        else:
            sigma = mult_coef * downsample_res / current_res
        sigma[downsample_res == 0] = 0

    else:

        # reformat data resolution at which we blur
        if thickness is not None:
            down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness])
        else:
            down_res = downsample_res

        # get std deviation for blurring kernels
        if mult_coef is None:
            sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')),
                                                 0.5, 0.75 * x / tf.convert_to_tensor(current_res,
                                                                                      dtype='float32')))(down_res)
        else:
            sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res)
        sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma])

    return sigma
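
# Illustrative usage (a sketch, not part of the original file), showing the
# numpy branch: blurring stds for simulating 3mm low-resolution data from a
# 1mm isotropic volume.
#
#   sigma = blurring_sigma_for_downsampling([1., 1., 1.], [3., 3., 3.])
#   # -> array([2.25, 2.25, 2.25]), i.e. 0.75 * downsample_res / current_res per axis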
|
|
def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build gaussian kernels of the specified standard deviation. The outputs are given as tensorflow tensors.
    :param sigma: standard deviation of the tensors. Can be given as a list/numpy array or as tensors. In each case,
    sigma must have the same length as the number of dimensions of the volume that will be blurred with the output
    tensors (e.g. sigma must have 3 values for 3D volumes).
    :param max_sigma: (optional) upper bound for sigma, used to derive the kernel size. Must be provided when sigma
    is given as a tensor, since the kernel size cannot be inferred from it at graph-building time.
    :param blur_range: (optional) if not None, sigma is multiplied by a coefficient randomly sampled from a uniform
    distribution in [1/blur_range, blur_range].
    :param separable: (optional) whether to return a list of 1d kernels (one per dimension, default), or a single
    n-dimensional kernel.
    :return: a list of n_dims 1d gaussian kernels (None for dimensions with window size 1) if separable is True,
    otherwise a single n-dimensional gaussian kernel.
    """
    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize
    if shape[0] is not None:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(utils.reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(utils.reformat_to_list(sigma, length=n_dims))

    # randomise the blurring std dev and/or split it between dimensions
    if blur_range is not None:
        if blur_range != 1:
            sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range,
                                                        maxval=blur_range)

    # get size of blurring kernels
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    comb[i] += 1

                # compute gaussians
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                g = g / tf.reduce_sum(g)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        kernels /= tf.reduce_sum(kernels)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
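
# Illustrative usage (a sketch, not part of the original file):
#
#   kernels = gaussian_kernel([1.5, 1.5, 1.5])                   # list of 3 separable 1d kernels
#   kernel = gaussian_kernel([1.5, 1.5, 1.5], separable=False)   # one dense 3d kernel
#
# With separable=True, each returned kernel has singleton dimensions everywhere
# except along its own axis, plus two trailing singleton channel dimensions
# (e.g. shape (5, 1, 1, 1, 1) for the first axis here), so it can be passed
# directly to tf.nn.convolution.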
|
|
def sobel_kernels(n_dims):
    """Returns sobel kernels to compute spatial derivatives on an image of n dimensions."""

    in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32')
    orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32')
    comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])

    list_kernels = list()
    for dim in range(n_dims):

        sublist_kernels = list()
        for axis in range(n_dims):

            kernel = in_dir if axis == dim else orthogonal_dir
            for i in comb[axis]:
                kernel = tf.expand_dims(kernel, axis=i)
            sublist_kernels.append(tf.expand_dims(tf.expand_dims(kernel, -1), -1))

        list_kernels.append(sublist_kernels)

    return list_kernels
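
# Illustrative note (a sketch, not part of the original file): for n_dims=2,
# list_kernels[0] holds the two separable factors of the sobel filter along
# axis 0, i.e. the derivative component [1, 0, -1] (shape (3, 1, 1, 1)) and the
# smoothing component [1, 2, 1] (shape (1, 3, 1, 1)), to be applied as
# successive convolutions.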
|
|
def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None):
    """Build a kernel with values of 1 for voxels at a distance <= dist_threshold from the centre, and 0 otherwise.
    The outputs are given as tensorflow tensors.
    :param dist_threshold: maximum distance from the centre up to which voxels have a value of 1. Can be a tensor of
    size (batch_size, 1), or a float.
    :param n_dims: dimension of the kernel to return (excluding batch and channel dimensions).
    :param max_dist_threshold: if dist_threshold is a tensor, max_dist_threshold must be given. It represents the
    maximum value that will be passed as dist_threshold. Must be a float.
    """

    # convert dist_threshold into a tensor
    if not tf.is_tensor(dist_threshold):
        dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32')
    else:
        assert max_dist_threshold is not None, \
            'max_dist_threshold must be provided when dist_threshold is given as a tensor'
        dist_threshold_tens = tf.cast(dist_threshold, 'float32')
    shape = dist_threshold_tens.get_shape().as_list()

    # get batchsize
    batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]

    # set max_dist_threshold into an array
    if max_dist_threshold is None:  # dist_threshold is fixed (i.e. it will not change at each mini-batch)
        max_dist_threshold = dist_threshold

    # get size of blurring kernels
    windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype='int32')

    # build tensor representing the distance from the centre
    mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
    dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
    dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1))

    # replicate distance to batch size and reshape dist_threshold_tens
    if batchsize is not None:
        dist = tf.tile(tf.expand_dims(dist, axis=0),
                       tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0))
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1)
    else:
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0)

    # build final kernel by thresholding the distance tensor
    kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist))
    kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)

    return kernel
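
# Illustrative note (a sketch, not part of the original file):
#
#   kernel = unit_kernel(1., 2)   # tensor of shape (3, 3, 1, 1)
#
# Here the centre and its four direct neighbours (distance <= 1) are set to 1,
# while the four corners (distance sqrt(2)) are set to 0.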
|
|
def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None,
                    build_reliability_map=False):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which
    slices were interpolated during resampling from the downsampled to final tensor.
    :param tensor: tensor to resample, of shape (batchsize, *spatial_shape, n_channels)
    :param resample_shape: list or numpy array of size (n_dims,)
    :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest'
    :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling
    step. List or numpy array of size (n_dims,). Default is None.
    :param volume_res: (optional) if subsample_res is not None, this should be provided to compute the downsampling
    ratio. List or numpy array of size (n_dims,). Default is None.
    :param build_reliability_map: whether to return a reliability map along with the resampled tensor. This map
    indicates which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice, values in between
    indicate the degree of realness).
    :return: resampled volume, with reliability map if necessary.
    """

    # reformat resolutions to lists
    subsample_res = utils.reformat_to_list(subsample_res)
    volume_res = utils.reformat_to_list(volume_res)
    n_dims = len(resample_shape)

    # downsample image
    tensor_shape = tensor.get_shape().as_list()[1:-1]
    downsample_shape = tensor_shape  # will be modified if we actually downsample

    if subsample_res is not None:
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
                                                      'had {0}, and {1}'.format(len(subsample_res), len(volume_res))
        if subsample_res != volume_res:

            # get shape at which we downsample
            downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)]

            # downsample volume
            tensor._keras_shape = tuple(tensor.get_shape().as_list())
            tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor)

    # resample image at target resolution
    if resample_shape != downsample_shape:  # if we didn't downsample, downsample_shape = tensor_shape
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor)

    # compute reliability maps if necessary and return results
    if build_reliability_map:

        # compute maps only if we downsampled
        if downsample_shape != tensor_shape:

            # compute upsampling factors
            upsampling_factors = np.array(resample_shape) / np.array(downsample_shape)

            # build reliability map
            reliability_map = 1
            for i in range(n_dims):
                loc_float = np.arange(0, resample_shape[i], upsampling_factors[i])
                loc_floor = np.int32(np.floor(loc_float))
                loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1))
                tmp_reliability_map = np.zeros(resample_shape[i])
                tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor)
                tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor)
                shape = [1, 1, 1]
                shape[i] = resample_shape[i]
                reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape)
            shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
            mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'),
                                                  shape=x))(shape)

        # otherwise just return an all-one tensor
        else:
            mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor)

        return tensor, mask

    else:
        return tensor
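
# Illustrative usage (a sketch, not part of the original file), for a keras
# tensor of shape (batch, 160, 160, 160, 1) with 1mm isotropic voxels: simulate
# 4mm slice spacing along the last axis, resample back onto the original 1mm
# grid, and get a map flagging real vs. interpolated slices:
#
#   tensor, reliability = resample_tensor(tensor, [160, 160, 160],
#                                         subsample_res=[1., 1., 4.],
#                                         volume_res=[1., 1., 1.],
#                                         build_reliability_map=True)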
|
|
def expand_dims(tensor, axis=0):
    """Expand the dimensions of the input tensor along the provided axes (given as an integer or a list)."""
    axis = utils.reformat_to_list(axis)
    for ax in axis:
        tensor = tf.expand_dims(tensor, axis=ax)
    return tensor
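
# Illustrative usage (a sketch, not part of the original file): add batch and
# channel dimensions to a bare 3d volume in one call:
#
#   tensor = expand_dims(tf.zeros([160, 160, 160]), axis=[0, -1])   # shape (1, 160, 160, 160, 1)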