a b/ext/lab2im/edit_tensors.py
1
"""
2
3
This file contains functions to handle keras/tensorflow tensors.
    - blurring_sigma_for_downsampling
    - gaussian_kernel
    - sobel_kernels
    - unit_kernel
    - resample_tensor
    - expand_dims
8
9
10
If you use this code, please cite the first SynthSeg paper:
11
https://github.com/BBillot/lab2im/blob/master/bibtex.bib
12
13
Copyright 2020 Benjamin Billot
14
15
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
16
compliance with the License. You may obtain a copy of the License at
17
https://www.apache.org/licenses/LICENSE-2.0
18
Unless required by applicable law or agreed to in writing, software distributed under the License is
19
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
20
implied. See the License for the specific language governing permissions and limitations under the
21
License.
22
23
"""
24
25
26
# python imports
27
import numpy as np
28
import tensorflow as tf
29
import keras.layers as KL
30
import keras.backend as K
31
from itertools import combinations
32
33
# project imports
34
from ext.lab2im import utils
35
36
# third-party imports
37
import ext.neuron.layers as nrn_layers
38
from ext.neuron.utils import volshape_to_meshgrid
39
40
41
def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None):
    """Compute standard deviations of 1d gaussian masks for image blurring before downsampling.

    The blurring resolution is downsample_res, or the element-wise minimum of downsample_res and thickness when a slice
    thickness is given. Sigma is mult_coef * blurring_res / current_res (mult_coef defaults to 0.75). With the default
    coefficient, sigma is set to 0.5 along axes where blurring_res equals current_res. In all cases, sigma is set to 0
    along axes where blurring_res is 0.
    :param current_res: resolution of the volume before downsampling.
    Can be a 1d numpy array or list or tensor of the same length as downsample_res.
    :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor.
    :param mult_coef: (optional) multiplicative coefficient for the blurring kernel. Default is 0.75.
    :param thickness: (optional) slice thickness in each dimension. Must be the same type as downsample_res.
    :return: standard deviation of the blurring masks: a numpy array if downsample_res is a list/numpy array, or a
    tensor (output of keras Lambda layers) if downsample_res is a tensor.
    """

    if tf.is_tensor(downsample_res):

        # resolution at which we blur: min between downsample_res and thickness (if given)
        if thickness is None:
            blur_res = downsample_res
        else:
            blur_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness])

        def _scale_with_default_coef(x):
            res = tf.convert_to_tensor(current_res, dtype='float32')
            # axes already at the current resolution get a fixed std dev of 0.5
            return tf.where(tf.math.equal(x, res), 0.5, 0.75 * x / res)

        def _scale_with_given_coef(x):
            return mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32')

        scale_fn = _scale_with_default_coef if mult_coef is None else _scale_with_given_coef
        sigma = KL.Lambda(scale_fn)(blur_res)

        # no blurring along axes whose blurring resolution is 0
        return KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([blur_res, sigma])

    # list / numpy inputs: same computation with numpy arrays
    cur_res = np.array(current_res)
    blur_res = np.array(downsample_res)
    if thickness is not None:
        blur_res = np.minimum(blur_res, np.array(thickness))

    coef = 0.75 if mult_coef is None else mult_coef
    sigma = coef * blur_res / cur_res
    if mult_coef is None:
        sigma[blur_res == cur_res] = 0.5
    sigma[blur_res == 0] = 0
    return sigma
84
85
86
def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build gaussian kernels of the specified standard deviation. The outputs are given as tensorflow tensors.
    :param sigma: standard deviation of the kernels. Can be given as a list/numpy array or as a tensor. In each case,
    sigma must have the same length as the number of dimensions of the volume that will be blurred with the output
    tensors (e.g. sigma must have 3 values for 3D volumes). A tensor of shape (batch_size, n_dims) can also be given,
    in which case one kernel is built (and normalised) for each batch element.
    :param max_sigma: (optional) maximum value(s) that sigma can take, used to compute the kernel sizes. Required when
    sigma is a tensor. Can be a number or a sequence of n_dims values. If None, kernel sizes are computed from sigma.
    :param blur_range: (optional) if given and different from 1, sigma is multiplied by a random factor sampled
    uniformly in [1/blur_range, blur_range] (one factor per value of sigma).
    :param separable: (optional) whether to return one 1d kernel per dimension (True, default), or a single
    n_dims-dimensional kernel (False).
    :return: if separable, a list with one entry per dimension: None when the window size along that dimension is 1
    (no blurring), otherwise a tensor whose window lies along that spatial axis, with singleton spatial axes elsewhere,
    a leading batch axis if sigma is batched, and two trailing singleton axes. If not separable, a single tensor of
    shape [(batch_size,) wsize_0, ..., wsize_n, 1, 1]. Each kernel sums to 1 (per batch element when batched).
    """
    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize: a rank-1 sigma is a single (n_dims,) vector, otherwise sigma is (batch, n_dims).
    # Deciding on the rank (not on whether the first dim is known) also supports a statically known batch size.
    if len(shape) == 1:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(utils.reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(utils.reformat_to_list(sigma, length=n_dims))

    # randomise the blurring std dev
    if blur_range is not None and blur_range != 1:
        sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)

    # get size of blurring kernels (always odd, derived from ceil(2.5 * max_sigma))
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        # for each axis i, comb[i] holds the other spatial axes, along which singleton dims are inserted
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    comb[i] += 1  # shift insertion axes to account for the leading batch axis

                # compute gaussians
                # NOTE(review): a sigma of exactly 0 with wsize > 1 (possible in dynamic mode) divides by zero here;
                # the non-separable branch avoids this by substituting 1 for null sigmas.
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                # normalise along the window only, so that each batch element gets its own unit-sum kernel
                g = g / tf.reduce_sum(g, axis=-1, keepdims=True)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians (null sigmas are replaced by 1 to avoid dividing by zero / taking log(0))
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        # normalise over the n_dims spatial axes only (i.e. separately for each batch element)
        kernels /= tf.reduce_sum(kernels, axis=list(range(-n_dims, 0)), keepdims=True)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
182
183
184
def sobel_kernels(n_dims):
    """Returns sobel kernels to compute spatial derivative on image of n dimensions.

    :param n_dims: number of spatial dimensions of the image.
    :return: nested list where element [dim][axis] is the 1d kernel to apply along `axis` when differentiating along
    `dim`: [1, 0, -1] along the derivative axis and [1, 2, 1] along the other axes. Each kernel is a float32 tensor
    with its 3 values on spatial axis `axis`, singleton spatial axes elsewhere, and two trailing singleton axes.
    """
    derivative_filter = tf.convert_to_tensor([1, 0, -1], dtype='float32')
    smoothing_filter = tf.convert_to_tensor([1, 2, 1], dtype='float32')
    # for each axis, the remaining spatial axes along which singleton dimensions must be inserted
    axes_to_expand = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])

    def _orient(filter_1d, axis):
        # place the 3 taps along spatial axis `axis`, then append the two trailing singleton axes
        for new_axis in axes_to_expand[axis]:
            filter_1d = tf.expand_dims(filter_1d, axis=new_axis)
        return tf.expand_dims(tf.expand_dims(filter_1d, -1), -1)

    return [[_orient(derivative_filter if axis == dim else smoothing_filter, axis) for axis in range(n_dims)]
            for dim in range(n_dims)]
205
206
207
def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None):
    """Build kernel with values of 1 for voxels at a distance <= dist_threshold from the center, and 0 otherwise.
    The outputs are given as tensorflow tensors.
    :param dist_threshold: maximum distance from the center for a voxel to have a value of 1. Can be a tensor of size
    (batch_size, 1), or a float.
    :param n_dims: dimension of the kernel to return (excluding batch and channel dimensions).
    :param max_dist_threshold: if dist_threshold is a tensor, max_dist_threshold must be given. It represents the
    maximum value that will be passed to dist_threshold, and is used to size the kernel. Must be a float.
    :return: float32 kernel of shape [(batch_size,) wsize, ..., wsize, 1, 1] with n_dims spatial axes, where
    wsize = 2 * floor(max_dist_threshold) + 1. The batch axis is only present when dist_threshold is batched.
    """

    # convert dist_threshold into a tensor
    if not tf.is_tensor(dist_threshold):
        dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32')
    else:
        assert max_dist_threshold is not None, \
            'max_dist_threshold must be provided when dist_threshold is given as a tensor'
        dist_threshold_tens = tf.cast(dist_threshold, 'float32')
    shape = dist_threshold_tens.get_shape().as_list()

    # get batchsize: dist_threshold is batched if it has rank >= 2 (e.g. (batch_size, 1), even with a statically known
    # batch size), or if it is a vector of unknown length
    is_batched = len(shape) > 1 or (len(shape) == 1 and shape[0] is None)
    batchsize = tf.split(tf.shape(dist_threshold_tens), [1, -1])[0] if is_batched else None

    # set max_dist_threshold
    if max_dist_threshold is None:  # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch)
        max_dist_threshold = dist_threshold

    # get size of kernels: the largest integer offset within the threshold on each side of a centre voxel, so that the
    # window is always odd and centred (identical to 2 * t + 1 for integer thresholds)
    half_size = int(np.floor(np.max(max_dist_threshold)))
    windowsize = np.array([half_size * 2 + 1] * n_dims, dtype='int32')

    # build tensor representing the distance from the centre
    mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
    dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
    dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1))

    # replicate distance to batch size and reshape dist_threshold_tens so that it broadcasts against dist
    if batchsize is not None:
        dist = tf.tile(tf.expand_dims(dist, axis=0),
                       tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0))
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1)
    else:
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0)

    # build final kernel by thresholding distance tensor
    kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist))
    kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)

    return kernel
255
256
257
def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None,
                    build_reliability_map=False):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which
    slices were interpolated during resampling from the downsampled to final tensor.
    :param tensor: tensor whose first and last axes are treated as batch and channel axes (spatial shape is read from
    tensor.get_shape()[1:-1]).
    :param resample_shape: list, tuple or numpy array of size (n_dims,)
    :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest'
    :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling
    step. List or numpy array of size (n_dims,). Default is None.
    :param volume_res: (optional) if subsample_res is not None, this should be provided to compute downsampling ratio.
    list or numpy array of size (n_dims,). Default is None.
    :param build_reliability_map: whether to return reliability map along with the resampled tensor. This map indicates
    which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice, in between=degree of realness).
    :return: resampled volume, with reliability map if necessary. The map is a float32 tensor with the same shape as
    the resampled volume when a downsampling was performed, and tf.ones_like(resampled volume) otherwise.
    """

    # reformat target shape and resolutions to lists. Converting resample_shape matters: comparing a numpy array with
    # a list in the `if` statements below would raise an "ambiguous truth value" error.
    resample_shape = [int(s) for s in resample_shape]
    subsample_res = utils.reformat_to_list(subsample_res)
    volume_res = utils.reformat_to_list(volume_res)
    n_dims = len(resample_shape)

    # downsample image
    tensor_shape = tensor.get_shape().as_list()[1:-1]
    downsample_shape = tensor_shape  # will be modified if we actually downsample

    if subsample_res is not None:
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
                                                      'had {0}, and {1}'.format(len(subsample_res), len(volume_res))
        if subsample_res != volume_res:

            # get shape at which we downsample
            downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)]

            # downsample volume with nearest-neighbour interpolation
            # (_keras_shape is set explicitly, presumably for the legacy keras Resize layer - TODO confirm)
            tensor._keras_shape = tuple(tensor.get_shape().as_list())
            tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor)

    # resample image at target resolution
    if resample_shape != downsample_shape:  # if we didn't downsample downsample_shape = tensor_shape
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor)

    if not build_reliability_map:
        return tensor

    # no downsampling: every slice is real, so just return an all-one tensor
    if downsample_shape == tensor_shape:
        mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor)
        return tensor, mask

    # compute upsampling factors
    upsampling_factors = np.array(resample_shape) / np.array(downsample_shape)

    # build reliability map as the product of one 1d profile per spatial axis
    reliability_array = np.ones([1] * n_dims)
    for i in range(n_dims):
        # position (in resampled voxels) of each slice of the downsampled tensor. Deriving it from the number of
        # downsampled slices gives exactly one position per slice, unlike np.arange with a float step.
        loc_float = np.arange(downsample_shape[i]) * upsampling_factors[i]
        loc_floor = np.int32(np.floor(loc_float))
        loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1))
        # split the weight of each real slice linearly between its two neighbouring resampled slices
        tmp_reliability_map = np.zeros(resample_shape[i])
        tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor)
        tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor)
        axis_shape = [1] * n_dims  # works for any number of spatial dimensions (was hard-coded to 3)
        axis_shape[i] = resample_shape[i]
        reliability_array = reliability_array * np.reshape(tmp_reliability_map, axis_shape)

    # add batch and channel axes so that the map broadcasts against the resampled tensor, whatever its batch size and
    # number of channels (a plain reshape to the tensor shape only worked when batch_size * n_channels == 1)
    reliability_array = np.reshape(reliability_array, [1] + resample_shape + [1])
    mask = KL.Lambda(lambda x: tf.ones(tf.shape(x), dtype='float32') *
                     tf.convert_to_tensor(reliability_array, dtype='float32'))(tensor)

    return tensor, mask
339
340
341
def expand_dims(tensor, axis=0):
    """Insert singleton dimensions into a tensor.

    :param tensor: input tensor.
    :param axis: (optional) axis, or list of axes, at which to insert a dimension. Axes are applied one after the
    other, in the given order, each on the result of the previous insertion. Default is 0.
    :return: the expanded tensor.
    """
    for single_axis in utils.reformat_to_list(axis):
        tensor = tf.expand_dims(tensor, axis=single_axis)
    return tensor