Diff of /ext/lab2im/layers.py [000000] .. [e571d1]


1
"""
2
This file contains several custom keras layers used in the generation model:
3
    - RandomSpatialDeformation,
4
    - RandomCrop,
5
    - RandomFlip,
6
    - SampleConditionalGMM,
7
    - SampleResolution,
8
    - GaussianBlur,
9
    - DynamicGaussianBlur,
10
    - MimicAcquisition,
11
    - BiasFieldCorruption,
12
    - IntensityAugmentation,
13
    - DiceLoss,
14
    - WeightedL2Loss,
15
    - ResetValuesToZero,
16
    - ConvertLabels,
17
    - PadAroundCentre,
18
    - MaskEdges
19
    - ImageGradients
20
    - RandomDilationErosion
21
22
23
If you use this code, please cite the first SynthSeg paper:
24
https://github.com/BBillot/lab2im/blob/master/bibtex.bib
25
26
Copyright 2020 Benjamin Billot
27
28
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
29
compliance with the License. You may obtain a copy of the License at
30
https://www.apache.org/licenses/LICENSE-2.0
31
Unless required by applicable law or agreed to in writing, software distributed under the License is
32
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
33
implied. See the License for the specific language governing permissions and limitations under the
34
License.
35
"""
36
37
38
# python imports
39
import keras
40
import numpy as np
41
import tensorflow as tf
42
import keras.backend as K
43
from keras.layers import Layer
44
45
# project imports
46
from ext.lab2im import utils
47
from ext.lab2im import edit_tensors as l2i_et
48
49
# third-party imports
50
from ext.neuron import utils as nrn_utils
51
import ext.neuron.layers as nrn_layers
52
53
54
class RandomSpatialDeformation(Layer):
55
    """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations.
56
    The input tensors are expected to have the same shape [batchsize, shape_dim1, ..., shape_dimn, channel].
57
    The non-linear deformation is obtained by:
58
    1) a small-size SVF is sampled from a centred normal distribution of random standard deviation.
59
    2) it is resized with trilinear interpolation to half the shape of the input tensor
60
    3) it is integrated to obtain a diffeomorphic transformation
61
    4) finally, it is resized (again with trilinear interpolation) to full image size
62
    :param scaling_bounds: (optional) range of the random scaling to apply. The scaling factor for each dimension is
63
    sampled from a uniform distribution of predefined bounds. Can either be:
64
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
65
    [1-scaling_bounds, 1+scaling_bounds] for each dimension.
66
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
67
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
68
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
69
     of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
70
    4) False, in which case scaling is completely turned off.
71
    Default is scaling_bounds = 0.15 (case 1)
72
    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
73
    and 2, the bounds are centred on 0 rather than 1, i.e. [0-rotation_bounds[i], 0+rotation_bounds[i]].
74
    Default is rotation_bounds = 10.
75
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.02.
76
    :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
77
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
78
    :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}.
79
    This is done regardless of the value of rotation_bounds. If true, a different value is sampled for each dimension.
80
    :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we
81
    sample the small-size SVF. Set to 0 if you wish to completely turn the elastic deformation off.
82
    :param nonlin_scale: (optional) if nonlin_std is strictly positive, ratio between the shape of the input tensor
83
    and the shape of the sampled small-size SVF.
84
    :param inter_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest'
85
    :param prob_deform: (optional) probability to apply spatial deformation
86
    """
87
88
    def __init__(self,
89
                 scaling_bounds=0.15,
90
                 rotation_bounds=10,
91
                 shearing_bounds=0.02,
92
                 translation_bounds=False,
93
                 enable_90_rotations=False,
94
                 nonlin_std=4.,
95
                 nonlin_scale=.0625,
96
                 inter_method='linear',
97
                 prob_deform=1,
98
                 **kwargs):
99
100
        # shape attributes
101
        self.n_inputs = 1
102
        self.inshape = None
103
        self.n_dims = None
104
        self.small_shape = None
105
106
        # deformation attributes
107
        self.scaling_bounds = scaling_bounds
108
        self.rotation_bounds = rotation_bounds
109
        self.shearing_bounds = shearing_bounds
110
        self.translation_bounds = translation_bounds
111
        self.enable_90_rotations = enable_90_rotations
112
        self.nonlin_std = nonlin_std
113
        self.nonlin_scale = nonlin_scale
114
115
        # boolean attributes
116
        self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \
117
                                  (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \
118
                                  self.enable_90_rotations
119
        self.apply_elastic_trans = self.nonlin_std > 0
120
        self.prob_deform = prob_deform
121
122
        # interpolation methods
123
        self.inter_method = inter_method
124
125
        super(RandomSpatialDeformation, self).__init__(**kwargs)
126
127
    def get_config(self):
128
        config = super().get_config()
129
        config["scaling_bounds"] = self.scaling_bounds
130
        config["rotation_bounds"] = self.rotation_bounds
131
        config["shearing_bounds"] = self.shearing_bounds
132
        config["translation_bounds"] = self.translation_bounds
133
        config["enable_90_rotations"] = self.enable_90_rotations
134
        config["nonlin_std"] = self.nonlin_std
135
        config["nonlin_scale"] = self.nonlin_scale
136
        config["inter_method"] = self.inter_method
137
        config["prob_deform"] = self.prob_deform
138
        return config
139
140
    def build(self, input_shape):
141
142
        if not isinstance(input_shape, list):
143
            inputshape = [input_shape]
144
        else:
145
            self.n_inputs = len(input_shape)
146
            inputshape = input_shape
147
        self.inshape = inputshape[0][1:]
148
        self.n_dims = len(self.inshape) - 1
149
150
        if self.apply_elastic_trans:
151
            self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims],
152
                                                        self.nonlin_scale, self.n_dims)
153
        else:
154
            self.small_shape = None
155
156
        self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str')
157
158
        self.built = True
159
        super(RandomSpatialDeformation, self).build(input_shape)
160
161
    def call(self, inputs, **kwargs):
162
163
        # reformat inputs and get its shape
164
        if self.n_inputs < 2:
165
            inputs = [inputs]
166
        types = [v.dtype for v in inputs]
167
        inputs = [tf.cast(v, dtype='float32') for v in inputs]
168
        batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0]
169
170
        # initialise list of transforms to operate
171
        list_trans = list()
172
173
        # add affine deformation to inputs list
174
        if self.apply_affine_trans:
175
            affine_trans = utils.sample_affine_transform(batchsize,
176
                                                         self.n_dims,
177
                                                         self.rotation_bounds,
178
                                                         self.scaling_bounds,
179
                                                         self.shearing_bounds,
180
                                                         self.translation_bounds,
181
                                                         self.enable_90_rotations)
182
            list_trans.append(affine_trans)
183
184
        # prepare non-linear deformation field and add it to inputs list
185
        if self.apply_elastic_trans:
186
187
            # sample small field from normal distribution of specified std dev
188
            trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0)
189
            trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std)
190
            elastic_trans = tf.random.normal(trans_shape, stddev=trans_std)
191
192
            # resize this field to half the image size (for a smoother SVF), integrate it, and resize to full size
193
            resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)]
194
            elastic_trans = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans)
195
            elastic_trans = nrn_layers.VecInt()(elastic_trans)
196
            elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans)
197
            list_trans.append(elastic_trans)
198
199
        # apply deformations and return tensors with correct dtype
200
        if self.apply_affine_trans | self.apply_elastic_trans:
201
            if self.prob_deform == 1:
202
                inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in
203
                          zip(self.inter_method, inputs)]
204
            else:
205
                rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform))
206
                inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v)
207
                          for (m, v) in zip(self.inter_method, inputs)]
208
        if self.n_inputs < 2:
209
            return tf.cast(inputs[0], types[0])
210
        else:
211
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]
212
213
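# Illustrative usage sketch, not part of the original file: a minimal model that deforms a label map and an
# image with the same random transform (nearest interpolation for the labels, linear for the image). The shapes
# and parameter values below are assumptions for the example only.
def _example_random_spatial_deformation():
    labels_in = keras.layers.Input(shape=(160, 160, 160, 1), dtype='int32')
    image_in = keras.layers.Input(shape=(160, 160, 160, 1))
    deform = RandomSpatialDeformation(scaling_bounds=0.15, rotation_bounds=10, nonlin_std=3.,
                                      inter_method=['nearest', 'linear'])
    labels_out, image_out = deform([labels_in, image_in])
    return keras.models.Model(inputs=[labels_in, image_in], outputs=[labels_out, image_out])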
214
class RandomCrop(Layer):
215
    """Randomly crop all input tensors to a given shape. This cropping is applied to all channels.
216
    The input tensors are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
217
    :param crop_shape: list with cropping shape in each dimension (excluding batch and channel dimension)
218
219
    example:
220
    if input is a tensor of shape [batchsize, 160, 160, 160, 3],
221
    output = RandomCrop(crop_shape=[96, 128, 96])(input)
222
    will yield an output of shape [batchsize, 96, 128, 96, 3] that is obtained by cropping with randomly selected
223
    cropping indices.
224
    """
225
226
    def __init__(self, crop_shape, **kwargs):
227
228
        self.several_inputs = True
229
        self.crop_max_val = None
230
        self.crop_shape = crop_shape
231
        self.n_dims = len(crop_shape)
232
        self.list_n_channels = None
233
        super(RandomCrop, self).__init__(**kwargs)
234
235
    def get_config(self):
236
        config = super().get_config()
237
        config["crop_shape"] = self.crop_shape
238
        return config
239
240
    def build(self, input_shape):
241
242
        if not isinstance(input_shape, list):
243
            self.several_inputs = False
244
            inputshape = [input_shape]
245
        else:
246
            inputshape = input_shape
247
        self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape)
248
        self.list_n_channels = [i[-1] for i in inputshape]
249
        self.built = True
250
        super(RandomCrop, self).build(input_shape)
251
252
    def call(self, inputs, **kwargs):
253
254
        # if only one input is provided, perform the cropping directly
255
        if not self.several_inputs:
256
            return tf.map_fn(self._single_slice, inputs, dtype=inputs.dtype)
257
258
        # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location
259
        else:
260
            types = [v.dtype for v in inputs]
261
            inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1)
262
            inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32)
263
            inputs = tf.split(inputs, self.list_n_channels, axis=-1)
264
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]
265
266
    def _single_slice(self, vol):
267
        crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32')
268
        crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0)
269
        crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32')
270
        return tf.slice(vol, begin=crop_idx, size=crop_size)
271
272
    def compute_output_shape(self, input_shape):
273
        output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels]
274
        return output_shape if self.several_inputs else output_shape[0]
275
276
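# Illustrative usage sketch, not part of the original file: following the docstring example above, a label map
# and an image are cropped at the same random location. Shapes are assumptions for the example only.
def _example_random_crop():
    labels_in = keras.layers.Input(shape=(160, 160, 160, 1), dtype='int32')
    image_in = keras.layers.Input(shape=(160, 160, 160, 3))
    # both outputs have spatial shape [96, 128, 96]
    labels_out, image_out = RandomCrop(crop_shape=[96, 128, 96])([labels_in, image_in])
    return keras.models.Model(inputs=[labels_in, image_in], outputs=[labels_out, image_out])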
277
class RandomFlip(Layer):
278
    """This layer randomly flips the input tensor along the specified axes with a specified probability.
279
    It can also take multiple tensors as inputs (if they have the same shape). The same flips will be applied to all
280
    input tensors. These are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
281
    If specified, this layer can also swap corresponding values. This is especially useful when flipping label maps
282
    with different labels for right/left structures, such that the flipped label maps keep a consistent labelling.
283
    :param axis: integer, or list of integers specifying the dimensions along which to flip.
284
    If a list, the input tensors can be flipped simultaneously in several directions. The values in axis exclude
285
    the batch dimension (e.g. 0 will flip the tensor along the first axis after the batch dimension).
286
    Default is None, where the tensors can be flipped along all axes (except batch and channel axes).
287
    :param swap_labels: boolean to specify whether to swap the values of each input. Values are only swapped if an odd
288
    number of flips is applied.
289
    Can also be a list if several tensors are given as input.
290
    All the inputs for which the values need to be swapped must be int32 or int64.
291
    :param label_list: if swap_labels is True, list of all labels contained in the input label maps. Must be ordered: first
292
     the neutral labels (i.e. non-sided), then left labels and right labels.
293
    :param n_neutral_labels: if swap_labels is True, number of non-sided labels
294
    :param prob: probability to flip along each specified axis
295
296
    example 1:
297
    if input is a tensor of shape (batchsize, 10, 100, 200, 3)
298
    output = RandomFlip()(input) will randomly flip input along one of the 1st, 2nd, or 3rd axis (i.e. those with shape
299
    10, 100, 200).
300
301
    example 2:
302
    if input is a tensor of shape (batchsize, 10, 100, 200, 3)
303
    output = RandomFlip(axis=1)(input) will randomly flip input along the axis of size 100, i.e. the axis
304
    with index 1 if we don't count the batch axis.
305
306
    example 3:
307
    input = tf.convert_to_tensor(np.array([[1, 0, 0, 0, 0, 0, 0],
308
                                           [1, 0, 0, 0, 2, 2, 0],
309
                                           [1, 0, 0, 0, 2, 2, 0],
310
                                           [1, 0, 0, 0, 2, 2, 0],
311
                                           [1, 0, 0, 0, 0, 0, 0]]))
312
    label_list = np.array([0, 1, 2])
313
    n_neutral_labels = 1
314
    output = RandomFlip(axis=1, swap_labels=True, label_list=label_list, n_neutral_labels=n_neutral_labels)(input)
315
    where output will either be equal to input (bear in mind the flipping occurs with a 0.5 probability), or:
316
    output = [[0, 0, 0, 0, 0, 0, 2],
317
              [0, 1, 1, 0, 0, 0, 2],
318
              [0, 1, 1, 0, 0, 0, 2],
319
              [0, 1, 1, 0, 0, 0, 2],
320
              [0, 0, 0, 0, 0, 0, 2]]
321
    Note that the input must have a dtype int32 or int64 for its values to be swapped, otherwise an error will be raised
322
323
    example 4:
324
    if labels is the same as the input of example 3, and image is a float32 image, then we can consistently flip both
325
    the labels and the image with:
326
    labels, image = RandomFlip(axis=1, swap_labels=[True, False], label_list=label_list,
327
                               n_neutral_labels=n_neutral_labels)([labels, image])
328
    Note that the labels must have a dtype int32 or int64 to be swapped, otherwise an error will be raised.
329
    This doesn't concern the image input, as its values are not swapped.
330
    """
331
332
    def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs):
333
334
        # shape attributes
335
        self.several_inputs = True
336
        self.n_dims = None
337
        self.list_n_channels = None
338
339
        # axis along which to flip
340
        self.axis = utils.reformat_to_list(axis)
341
        self.flip_axes = None
342
343
        # whether to swap labels, and corresponding label list
344
        self.swap_labels = utils.reformat_to_list(swap_labels)
345
        self.label_list = label_list
346
        self.n_neutral_labels = n_neutral_labels
347
        self.swap_lut = None
348
349
        self.prob = prob
350
351
        super(RandomFlip, self).__init__(**kwargs)
352
353
    def get_config(self):
354
        config = super().get_config()
355
        config["axis"] = self.axis
356
        config["swap_labels"] = self.swap_labels
357
        config["label_list"] = self.label_list
358
        config["n_neutral_labels"] = self.n_neutral_labels
359
        config["prob"] = self.prob
360
        return config
361
362
    def build(self, input_shape):
363
364
        if not isinstance(input_shape, list):
365
            self.several_inputs = False
366
            inputshape = [input_shape]
367
        else:
368
            inputshape = input_shape
369
        self.n_dims = len(inputshape[0][1:-1])
370
        self.list_n_channels = [i[-1] for i in inputshape]
371
        self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape))
372
        self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis
373
374
        # create label list with swapped labels
375
        if any(self.swap_labels):
376
            assert (self.label_list is not None) & (self.n_neutral_labels is not None), \
377
                'please provide a label_list, and n_neutral_labels when swapping the values of at least one input'
378
            n_labels = len(self.label_list)
379
            if self.n_neutral_labels == n_labels:
380
                self.swap_labels = [False] * len(self.swap_labels)
381
            else:
382
                rl_split = np.split(self.label_list, [self.n_neutral_labels,
383
                                                      self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)])
384
                label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1]))
385
                swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap)
386
                self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32')
387
388
        self.built = True
389
        super(RandomFlip, self).build(input_shape)
390
391
    def call(self, inputs, **kwargs):
392
393
        # convert inputs to list, and get each input type
394
        inputs = [inputs] if not self.several_inputs else inputs
395
        types = [v.dtype for v in inputs]
396
397
        # store whether to flip along each specified dimension
398
        batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0]
399
        size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0)
400
        rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob)
401
402
        # swap right/left labels if we apply an odd number of flips
403
        odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0
404
        swapped_inputs = list()
405
        for i in range(len(inputs)):
406
            if self.swap_labels[i]:
407
                swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i]))
408
            else:
409
                swapped_inputs.append(inputs[i])
410
411
        # flip inputs and convert them back to their original type
412
        inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1)
413
        inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32)
414
        inputs = tf.split(inputs, self.list_n_channels, axis=-1)
415
416
        if self.several_inputs:
417
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]
418
        else:
419
            return tf.cast(inputs[0], types[0])
420
421
    def _single_swap(self, inputs):
422
        return K.switch(inputs[1], tf.gather(self.swap_lut, inputs[0]), inputs[0])
423
424
    @staticmethod
425
    def _single_flip(inputs):
426
        flip_axis = tf.where(inputs[1])
427
        return K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0]))
428
429
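# Illustrative usage sketch, not part of the original file: following example 4 of the docstring, a label map and
# an image are flipped together along axis 1, and right/left labels are swapped in the label map only. The label
# values and tensor shapes are assumptions for the example only.
def _example_random_flip():
    label_list = np.array([0, 1, 2])  # one neutral label, then the left label, then the right label
    labels_in = keras.layers.Input(shape=(10, 100, 200, 1), dtype='int32')
    image_in = keras.layers.Input(shape=(10, 100, 200, 1))
    flip = RandomFlip(axis=1, swap_labels=[True, False], label_list=label_list, n_neutral_labels=1)
    labels_out, image_out = flip([labels_in, image_in])
    return keras.models.Model(inputs=[labels_in, image_in], outputs=[labels_out, image_out])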
430
class SampleConditionalGMM(Layer):
431
    """This layer generates an image by sampling a Gaussian Mixture Model conditioned on a label map given as input.
432
    The parameters of the GMM are given as two additional inputs to the layer (means and standard deviations):
433
    image = SampleConditionalGMM(generation_labels)([label_map, means, stds])
434
435
    :param generation_labels: list of all possible label values contained in the input label maps.
436
    Must be a list or a 1D numpy array of size N, where N is the total number of possible label values.
437
438
    Layer inputs:
439
    label_map: input label map of shape [batchsize, shape_dim1, ..., shape_dimn, n_channel].
440
    All the values of label_map must be contained in generation_labels, but the input label_map doesn't necessarily have
441
    to contain all the values in generation_labels.
442
    means: tensor containing the mean values of all Gaussian distributions of the GMM.
443
           It must be of shape [batchsize, N, n_channel], and in the same order as generation_labels,
444
           i.e. the ith value of generation_labels will be associated to the ith value of means.
445
    stds: same as means but for the standard deviations of the GMM.
446
    """
447
448
    def __init__(self, generation_labels, **kwargs):
449
        self.generation_labels = generation_labels
450
        self.n_labels = None
451
        self.n_channels = None
452
        self.max_label = None
453
        self.indices = None
454
        self.shape = None
455
        super(SampleConditionalGMM, self).__init__(**kwargs)
456
457
    def get_config(self):
458
        config = super().get_config()
459
        config["generation_labels"] = self.generation_labels
460
        return config
461
462
    def build(self, input_shape):
463
464
        # check n_labels and n_channels
465
        assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).'
466
        self.n_channels = input_shape[1][-1]
467
        self.n_labels = len(self.generation_labels)
468
        assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels'
469
        assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels'
470
471
        # scatter parameters (to build mean/std lut)
472
        self.max_label = np.max(self.generation_labels) + 1
473
        indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1)
474
        self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32')
475
        self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32')
476
477
        self.built = True
478
        super(SampleConditionalGMM, self).build(input_shape)
479
480
    def call(self, inputs, **kwargs):
481
482
        # reformat labels and scatter indices
483
        batch = tf.split(tf.shape(inputs[0]), [1, -1])[0]
484
        tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0))
485
        labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1)
486
487
        # build mean map
488
        means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1)
489
        tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0)
490
        means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape)
491
        means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32)
492
493
        # same for stds
494
        stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1)
495
        stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape)
496
        stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32)
497
498
        return stds_map * tf.random.normal(tf.shape(labels)) + means_map
499
500
    def compute_output_shape(self, input_shape):
501
        return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels])
502
503
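# Illustrative usage sketch, not part of the original file: sample a synthetic image from a label map, given one
# Gaussian (mean, std) per label in generation_labels. Label values and shapes are assumptions for the example only.
def _example_sample_conditional_gmm():
    generation_labels = np.array([0, 1, 2, 3])
    labels_in = keras.layers.Input(shape=(160, 160, 160, 1), dtype='int32')
    means_in = keras.layers.Input(shape=(len(generation_labels), 1))  # shape [batchsize, N, n_channels]
    stds_in = keras.layers.Input(shape=(len(generation_labels), 1))
    image = SampleConditionalGMM(generation_labels)([labels_in, means_in, stds_in])
    return keras.models.Model(inputs=[labels_in, means_in, stds_in], outputs=image)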
504
class SampleResolution(Layer):
505
    """Build a random resolution tensor by sampling a uniform distribution of provided range.
506
507
    You can use this layer in the following ways:
508
        resolution = SampleResolution(min_resolution)(), in which case resolution will be a tensor of shape (n_dims,),
509
        where n_dims is the length of the min_resolution parameter (provided as a list, see below).
510
        resolution = SampleResolution(min_resolution)(input), where input is a tensor for which the first dimension
511
        represents the batch_size. In this case resolution will be a tensor of shape (batchsize, n_dims,).
512
513
    :param min_resolution: list of length n_dims specifying the lower bounds of the uniform distributions to
514
    sample from for each value.
515
    :param max_res_iso: If not None, all the values of resolution will be equal to the same value, which is randomly
516
    sampled at each minibatch in U(min_resolution, max_res_iso).
517
    :param max_res_aniso: If not None, we first randomly select a direction i in the range [0, n_dims-1], and we sample
518
    a value in the corresponding uniform distribution U(min_resolution[i], max_res_aniso[i]).
519
    The other values of resolution will be set to min_resolution.
520
    :param prob_iso: if both max_res_iso and max_res_aniso are specified, probability of
521
    sampling an isotropic resolution (therefore using max_res_iso) rather than an anisotropic resolution
522
    (which would use max_res_aniso).
523
    :param prob_min: if not zero, probability of returning an output resolution equal
524
    to min_resolution.
525
    :param return_thickness: if set to True, this layer will also return a thickness value of the same shape as
526
    resolution, which will be sampled independently for each axis from the uniform distribution
527
    U(min_resolution, resolution).
528
529
    """
530
531
    def __init__(self,
532
                 min_resolution,
533
                 max_res_iso=None,
534
                 max_res_aniso=None,
535
                 prob_iso=0.1,
536
                 prob_min=0.05,
537
                 return_thickness=True,
538
                 **kwargs):
539
540
        self.min_res = min_resolution
541
        self.max_res_iso_input = max_res_iso
542
        self.max_res_iso = None
543
        self.max_res_aniso_input = max_res_aniso
544
        self.max_res_aniso = None
545
        self.prob_iso = prob_iso
546
        self.prob_min = prob_min
547
        self.return_thickness = return_thickness
548
        self.n_dims = len(self.min_res)
549
        self.add_batchsize = False
550
        self.min_res_tens = None
551
        super(SampleResolution, self).__init__(**kwargs)
552
553
    def get_config(self):
554
        config = super().get_config()
555
        config["min_resolution"] = self.min_res
556
        config["max_res_iso"] = self.max_res_iso
557
        config["max_res_aniso"] = self.max_res_aniso
558
        config["prob_iso"] = self.prob_iso
559
        config["prob_min"] = self.prob_min
560
        config["return_thickness"] = self.return_thickness
561
        return config
562
563
    def build(self, input_shape):
564
565
        # check maximum resolutions
566
        assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \
567
            'at least one of maximum isotropic or anisotropic resolutions must be provided, received none'
568
569
        # reformat resolutions as numpy arrays
570
        self.min_res = np.array(self.min_res)
571
        if self.max_res_iso_input is not None:
572
            self.max_res_iso = np.array(self.max_res_iso_input)
573
            assert len(self.min_res) == len(self.max_res_iso), \
574
                'min and isotropic max resolution must have the same length, ' \
575
                'had {0} and {1}'.format(self.min_res, self.max_res_iso)
576
            if np.array_equal(self.min_res, self.max_res_iso):
577
                self.max_res_iso = None
578
        if self.max_res_aniso_input is not None:
579
            self.max_res_aniso = np.array(self.max_res_aniso_input)
580
            assert len(self.min_res) == len(self.max_res_aniso), \
581
                'min and anisotropic max resolution must have the same length, ' \
582
                'had {} and {}'.format(self.min_res, self.max_res_aniso)
583
            if np.array_equal(self.min_res, self.max_res_aniso):
584
                self.max_res_aniso = None
585
586
        # check prob iso
587
        if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0):
588
            raise Exception('prob_iso is 0 while sampling both isotropic and anisotropic resolutions is enabled')
589
590
        if input_shape:
591
            self.add_batchsize = True
592
593
        self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')
594
595
        self.built = True
596
        super(SampleResolution, self).build(input_shape)
597
598
    def call(self, inputs, **kwargs):
599
600
        if not self.add_batchsize:
601
            shape = [self.n_dims]
602
            dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32')
603
            mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim,
604
                                               tf.convert_to_tensor([True], dtype='bool'))
605
        else:
606
            batch = tf.split(tf.shape(inputs), [1, -1])[0]
607
            tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0)
608
            self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)
609
610
            shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0)
611
            indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1)
612
            mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool'))
613
614
        # return min resolution as tensor if min=max
615
        if (self.max_res_iso is None) & (self.max_res_aniso is None):
616
            new_resolution = self.min_res_tens
617
618
        # sample isotropic resolution only
619
        elif (self.max_res_iso is not None) & (self.max_res_aniso is None):
620
            new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso)
621
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
622
                                      self.min_res_tens,
623
                                      new_resolution_iso)
624
625
        # sample anisotropic resolution only
626
        elif (self.max_res_iso is None) & (self.max_res_aniso is not None):
627
            new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso)
628
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
629
                                      self.min_res_tens,
630
                                      tf.where(mask, new_resolution_aniso, self.min_res_tens))
631
632
        # sample either anisotropic or isotropic resolution
633
        else:
634
            new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso)
635
            new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso)
636
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)),
637
                                      new_resolution_iso,
638
                                      tf.where(mask, new_resolution_aniso, self.min_res_tens))
639
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
640
                                      self.min_res_tens,
641
                                      new_resolution)
642
643
        if self.return_thickness:
644
            return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)]
645
        else:
646
            return new_resolution
647
648
    def compute_output_shape(self, input_shape):
649
        if self.return_thickness:
650
            return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2
651
        else:
652
            return (None, self.n_dims) if self.add_batchsize else self.n_dims
653
654
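# Illustrative usage sketch, not part of the original file: sample a random acquisition resolution (and a slice
# thickness) for each element of the batch; the image input is only used here to carry the batch size. The
# resolution bounds below are assumptions for the example only.
def _example_sample_resolution():
    image_in = keras.layers.Input(shape=(160, 160, 160, 1))
    resolution, thickness = SampleResolution(min_resolution=[1., 1., 1.],
                                             max_res_iso=[4., 4., 4.],
                                             max_res_aniso=[9., 9., 9.])(image_in)
    return keras.models.Model(inputs=image_in, outputs=[resolution, thickness])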
655
class GaussianBlur(Layer):
656
    """Applies gaussian blur to an input image.
657
    The input image is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
658
    :param sigma: standard deviation of the blurring kernels to apply. Can be a number, a list of length n_dims, or
659
    a numpy array.
660
    :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where
661
    sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds
662
    [1/random_blur_range, random_blur_range].
663
    :param use_mask: (optional) whether a mask of the input will be provided as an additional layer input. This is used
664
    to mask the blurred image, and to correct for edge blurring effects.
665
666
    example 1:
667
    output = GaussianBlur(sigma=0.5)(input) will isotropically blur the input with a gaussian kernel of std 0.5.
668
669
    example 2:
670
    if input is a tensor of shape [batchsize, 10, 100, 200, 2]
671
    output = GaussianBlur(sigma=[0.5, 1, 10])(input) will blur the input with a different gaussian kernel in each dimension.
672
673
    example 3:
674
    output = GaussianBlur(sigma=0.5, random_blur_range=1.15)(input)
675
    will blur the input with a different gaussian kernel in each dimension, as each dimension will be associated with
676
    a kernel whose standard deviation will be uniformly sampled from [0.5/1.15, 0.5*1.15].
677
678
    example 4:
679
    output = GaussianBlur(sigma=0.5, use_mask=True)([input, mask])
680
    will 1) blur the input with a gaussian kernel of std 0.5, 2) mask the blurred image with the provided
681
    mask, and 3) correct for edge blurring effects. If the provided mask is not of boolean type, it will be cast to
682
    boolean (all non-zero values are treated as part of the mask).
683
    """
684
685
    def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
686
        self.sigma = utils.reformat_to_list(sigma)
687
        assert np.all(np.array(self.sigma) >= 0), 'sigma should be greater than or equal to 0'
688
        self.use_mask = use_mask
689
690
        self.n_dims = None
691
        self.n_channels = None
692
        self.blur_range = random_blur_range
693
        self.stride = None
694
        self.separable = None
695
        self.kernels = None
696
        self.convnd = None
697
        super(GaussianBlur, self).__init__(**kwargs)
698
699
    def get_config(self):
700
        config = super().get_config()
701
        config["sigma"] = self.sigma
702
        config["random_blur_range"] = self.blur_range
703
        config["use_mask"] = self.use_mask
704
        return config
705
706
    def build(self, input_shape):
707
708
        # get shapes
709
        if self.use_mask:
710
            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
711
            self.n_dims = len(input_shape[0]) - 2
712
            self.n_channels = input_shape[0][-1]
713
        else:
714
            self.n_dims = len(input_shape) - 2
715
            self.n_channels = input_shape[-1]
716
717
        # prepare blurring kernel
718
        self.stride = [1] * (self.n_dims + 2)
719
        self.sigma = utils.reformat_to_list(self.sigma, length=self.n_dims)
720
        self.separable = np.linalg.norm(np.array(self.sigma)) > 5
721
        if self.blur_range is None:  # fixed kernels
722
            self.kernels = l2i_et.gaussian_kernel(self.sigma, separable=self.separable)
723
        else:
724
            self.kernels = None
725
726
        # prepare convolution
727
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
728
729
        self.built = True
730
        super(GaussianBlur, self).build(input_shape)
731
732
    def call(self, inputs, **kwargs):
733
734
        if self.use_mask:
735
            image = inputs[0]
736
            mask = tf.cast(inputs[1], 'bool')
737
        else:
738
            image = inputs
739
            mask = None
740
741
        # redefine the kernels at each new step when blur_range is activated
742
        if self.blur_range is not None:
743
            self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)
744
745
        if self.separable:
746
            for k in self.kernels:
747
                if k is not None:
748
                    image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME')
749
                                       for n in range(self.n_channels)], -1)
750
                    if self.use_mask:
751
                        maskb = tf.cast(mask, 'float32')
752
                        maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME')
753
                                           for n in range(self.n_channels)], -1)
754
                        image = image / (maskb + K.epsilon())
755
                        image = tf.where(mask, image, tf.zeros_like(image))
756
        else:
757
            if any(self.sigma):
758
                image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME')
759
                                   for n in range(self.n_channels)], -1)
760
                if self.use_mask:
761
                    maskb = tf.cast(mask, 'float32')
762
                    maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME')
763
                                       for n in range(self.n_channels)], -1)
764
                    image = image / (maskb + K.epsilon())
765
                    image = tf.where(mask, image, tf.zeros_like(image))
766
767
        return image
768
769
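# Illustrative usage sketch, not part of the original file: blur an image with a fixed anisotropic kernel whose
# std devs are randomised by a small factor. Shape and sigma values are assumptions for the example only.
def _example_gaussian_blur():
    image_in = keras.layers.Input(shape=(160, 160, 160, 1))
    blurred = GaussianBlur(sigma=[0.5, 0.5, 3.], random_blur_range=1.15)(image_in)
    return keras.models.Model(inputs=image_in, outputs=blurred)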
770
class DynamicGaussianBlur(Layer):
771
    """Applies gaussian blur to an input image, where the standard deviation of the blurring kernel is provided as a
772
    layer input, which enables dynamic blurring (i.e. the blurring kernel can vary at each minibatch).
773
    :param max_sigma: maximum value of the standard deviation that will be provided as input. This is used to compute
774
    the size of the blurring kernels. It must be provided as a list of length n_dims.
775
    :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where
776
    sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds
777
    [1/random_blur_range, random_blur_range].
778
779
    example:
780
    blurred_image = DynamicGaussianBlur(max_sigma=[5.]*3, random_blur_range=1.15)([image, sigma])
781
    will return a blurred version of image, where the standard deviation of each dimension (given as a tensor, and with
782
    values lower than 5 for each axis) is multiplied by a random coefficient uniformly sampled from [1/1.15; 1.15].
783
    """
784
785
    def __init__(self, max_sigma, random_blur_range=None, **kwargs):
786
        self.max_sigma = max_sigma
787
        self.n_dims = None
788
        self.n_channels = None
789
        self.convnd = None
790
        self.blur_range = random_blur_range
791
        self.separable = None
792
        super(DynamicGaussianBlur, self).__init__(**kwargs)
793
794
    def get_config(self):
795
        config = super().get_config()
796
        config["max_sigma"] = self.max_sigma
797
        config["random_blur_range"] = self.blur_range
798
        return config
799
800
    def build(self, input_shape):
801
        assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring'
802
        self.n_dims = len(input_shape[0]) - 2
803
        self.n_channels = input_shape[0][-1]
804
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
805
        self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims)
806
        self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5
807
        self.built = True
808
        super(DynamicGaussianBlur, self).build(input_shape)
809
810
    def call(self, inputs, **kwargs):
811
        image = inputs[0]
812
        sigma = inputs[-1]
813
        kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable)
814
        if self.separable:
815
            for kernel in kernels:
816
                image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32)
817
        else:
818
            image = tf.map_fn(self._single_blur, [image, kernels], dtype=tf.float32)
819
        return image
820
821
    def _single_blur(self, inputs):
822
        if self.n_channels > 1:
823
            split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1)
824
            blurred_channel = list()
825
            for channel in split_channels:
826
                blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME')
827
                blurred_channel.append(tf.squeeze(blurred, axis=0))
828
            output = tf.concat(blurred_channel, -1)
829
        else:
830
            output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME')
831
            output = tf.squeeze(output, axis=0)
832
        return output
833
834
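# Illustrative usage sketch, not part of the original file: following the docstring example, the blurring std devs
# are given as a second input so they can change at every minibatch. Shapes and max_sigma are assumptions.
def _example_dynamic_gaussian_blur():
    image_in = keras.layers.Input(shape=(160, 160, 160, 1))
    sigma_in = keras.layers.Input(shape=(3,))  # one std dev per spatial axis, each expected below max_sigma
    blurred = DynamicGaussianBlur(max_sigma=[5.] * 3, random_blur_range=1.15)([image_in, sigma_in])
    return keras.models.Model(inputs=[image_in, sigma_in], outputs=blurred)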
835
class MimicAcquisition(Layer):
836
    """
837
    Layer that takes an image as input, and simulates data that has been acquired at low resolution.
838
    The output is obtained by resampling the input twice:
839
     - first at a resolution given as an input (i.e. the "acquisition" resolution),
840
     - then at the output resolution (specified output shape).
841
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
842
843
    :param volume_res: resolution of the provided inputs. Must be a 1-D numpy array with n_dims elements.
844
    :param min_subsample_res: lower bound of the acquisition resolutions to mimic (i.e. the resolutions given as input must have
845
    values higher than min_subsample_res).
846
    :param resample_shape: shape of the output tensor
847
    :param build_dist_map: whether to return distance maps as outputs. These indicate the distance between each voxel
848
    and the nearest non-interpolated voxel (during the second resampling).
849
    :param noise_std: (optional) maximum value of the standard deviation of the white noise added to the volume after
    the first resampling (the std dev itself is uniformly sampled in [0, noise_std]). Default is 0 (no noise).
    :param prob_noise: probability to apply noise injection
850
851
    example 1:
852
    im_res = [1., 1., 1.]
853
    low_res = [1., 1., 3.]
854
    res = tf.convert_to_tensor([1., 1., 4.5])
855
    image is a tensor of shape (None, 256, 256, 256, 3)
856
    resample_shape = [256, 256, 256]
857
    output = MimicAcquisition(im_res, low_res, resample_shape)([image, res])
858
    output will be a tensor of shape (None, 256, 256, 256, 3), obtained by downsampling image to [1., 1., 4.5].
859
    and re-upsampling it at initial resolution (because resample_shape is equal to the input shape). In this example all
860
    elements of the batch will be downsampled to the same resolution (because res has no batch dimension).
861
    Note that the provided res must have higher values than low_res.
862
863
    example 2:
864
    im_res = [1., 1., 1.]
865
    min_low_res = [1., 1., 1.]
866
    res is a tensor of shape (None, 3), obtained for example by using the SampleResolution layer (see above).
867
    image is a tensor of shape (None, 256, 256, 256, 1)
868
    resample_shape = [128, 128, 128]
869
    output = MimicAcquisition(im_res, min_low_res, resample_shape)([image, res])
870
    output will be a tensor of shape (None, 128, 128, 128, 1), obtained by downsampling each element of the batch to
871
    the matching resolution in res, and resampling them all to half the initial resolution.
872
    Note that the provided res must have higher values than min_low_res.
873
    """
874
875
    def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False,
876
                 noise_std=0, prob_noise=0.95, **kwargs):
877
878
        # resolutions and dimensions
879
        self.volume_res = volume_res
880
        self.min_subsample_res = min_subsample_res
881
        self.n_dims = len(self.volume_res)
882
        self.n_channels = None
883
        self.add_batchsize = None
884
885
        # noise
886
        self.noise_std = noise_std
887
        self.prob_noise = prob_noise
888
889
        # input and output shapes
890
        self.inshape = None
891
        self.resample_shape = resample_shape
892
893
        # meshgrids for resampling
894
        self.down_grid = None
895
        self.up_grid = None
896
897
        # whether to return a map indicating the distance from the interpolated voxels to the acquired ones.
898
        self.build_dist_map = build_dist_map
899
900
        super(MimicAcquisition, self).__init__(**kwargs)
901
902
    def get_config(self):
903
        config = super().get_config()
904
        config["volume_res"] = self.volume_res
905
        config["min_subsample_res"] = self.min_subsample_res
906
        config["resample_shape"] = self.resample_shape
907
        config["build_dist_map"] = self.build_dist_map
908
        config["noise_std"] = self.noise_std
909
        config["prob_noise"] = self.prob_noise
910
        return config
911
912
    def build(self, input_shape):
913
914
        # set up input shape and acquisition shape
915
        self.inshape = input_shape[0][1:]
916
        self.n_channels = input_shape[0][-1]
917
        self.add_batchsize = False if (input_shape[1][0] is None) else True
918
        down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res)
919
920
        # build interpolation meshgrids
921
        self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0)
922
        self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0)
923
924
        self.built = True
925
        super(MimicAcquisition, self).build(input_shape)
926
927
    def call(self, inputs, **kwargs):
928
929
        # sort inputs
930
        assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution'
931
        vol = inputs[0]
932
        subsample_res = tf.cast(inputs[1], dtype='float32')
933
        vol = K.reshape(vol, [-1, *self.inshape])  # necessary for multi_gpu models
934
        batchsize = tf.split(tf.shape(vol), [1, -1])[0]
935
        tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0)
936
937
        # get downsampling and upsampling factors
938
        if self.add_batchsize:
939
            subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape)
940
        down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') /
941
                             subsample_res, dtype='int32')
942
        down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32')
943
        up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32')
944
945
        # downsample
946
        down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0))
947
        down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims)
948
        inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape)
949
        inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims)
950
        down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32'))
951
        vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32)
952
953
        # add noise with predefined probability
954
        if self.noise_std > 0:
955
            sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'),
956
                                      self.n_channels * tf.ones([1], dtype='int32')], 0)
957
            noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std))
958
            if self.prob_noise == 1:
959
                vol += noise
960
            else:
961
                vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol)
962
963
        # upsample
964
        up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0))
965
        up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims)
966
        vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32)
967
968
        # return upsampled volume
969
        if not self.build_dist_map:
970
            return vol
971
972
        # return upsampled volumes with distance maps
973
        else:
974
975
            # get grid points
976
            floor = tf.math.floor(up_loc)
977
            ceil = tf.math.ceil(up_loc)
978
979
            # get distances of every voxel to higher and lower grid points for every dimension
980
            f_dist = up_loc - floor
981
            c_dist = ceil - up_loc
982
983
            # keep minimum 1d distances, and compute 3d distance to nearest grid point
984
            dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims)
985
            dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True))
986
987
            return [vol, dist]
988
989
    @staticmethod
990
    def _single_down_interpn(inputs):
991
        return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest')
992
993
    @staticmethod
994
    def _single_up_interpn(inputs):
995
        return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear')
996
997
    def compute_output_shape(self, input_shape):
998
        output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]])
999
        return [output_shape] * 2 if self.build_dist_map else output_shape
1000
1001
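# Illustrative usage sketch, not part of the original file: following example 2 of the docstring, each element of
# the batch is downsampled to its own resolution (second input) and then resampled to a fixed output shape.
# The shapes and resolutions below are assumptions for the example only.
def _example_mimic_acquisition():
    image_in = keras.layers.Input(shape=(256, 256, 256, 1))
    res_in = keras.layers.Input(shape=(3,))  # acquisition resolution, e.g. sampled by SampleResolution
    resampled = MimicAcquisition(volume_res=[1.] * 3, min_subsample_res=[1.] * 3,
                                 resample_shape=[128, 128, 128])([image_in, res_in])
    return keras.models.Model(inputs=[image_in, res_in], outputs=resampled)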
1002
class BiasFieldCorruption(Layer):
1003
    """This layer applies a smooth random bias field to the input by applying the following steps:
1004
    1) we first sample a value for the standard deviation of a centred normal distribution
1005
    2) a small-size SVF is sampled from this normal distribution
1006
    3) the small SVF is then resized with trilinear interpolation to image size
1007
    4) it is rescaled to positive values by taking the voxel-wise exponential
1008
    5) it is multiplied with the input tensor.
1009
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
1010
1011
    :param bias_field_std: maximum value of the standard deviation sampled in 1 (it will be sampled from the range
1012
    [0, bias_field_std])
1013
    :param bias_scale: ratio between the shape of the input tensor and the shape of the sampled SVF.
1014
    :param same_bias_for_all_channels: whether to apply the same bias field to all the channels of the input tensor.
1015
    :param prob: probability to apply this bias field corruption.
1016
    """
1017
1018
    def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs):
1019
1020
        # input shape
1021
        self.several_inputs = False
1022
        self.inshape = None
1023
        self.n_dims = None
1024
        self.n_channels = None
1025
1026
        # sampling shape
1027
        self.std_shape = None
1028
        self.small_bias_shape = None
1029
1030
        # bias field parameters
1031
        self.bias_field_std = bias_field_std
1032
        self.bias_scale = bias_scale
1033
        self.same_bias_for_all_channels = same_bias_for_all_channels
1034
        self.prob = prob
1035
1036
        super(BiasFieldCorruption, self).__init__(**kwargs)
1037
1038
    def get_config(self):
1039
        config = super().get_config()
1040
        config["bias_field_std"] = self.bias_field_std
1041
        config["bias_scale"] = self.bias_scale
1042
        config["same_bias_for_all_channels"] = self.same_bias_for_all_channels
1043
        config["prob"] = self.prob
1044
        return config
1045
1046
    def build(self, input_shape):
1047
1048
        # input shape
1049
        if isinstance(input_shape, list):
1050
            self.several_inputs = True
1051
            self.inshape = input_shape
1052
        else:
1053
            self.inshape = [input_shape]
1054
        self.n_dims = len(self.inshape[0]) - 2
1055
        self.n_channels = self.inshape[0][-1]
1056
1057
        # sampling shapes
1058
        self.std_shape = [1] * (self.n_dims + 1)
1059
        self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1)
1060
        if not self.same_bias_for_all_channels:
1061
            self.std_shape[-1] = self.n_channels
1062
            self.small_bias_shape[-1] = self.n_channels
1063
1064
        self.built = True
1065
        super(BiasFieldCorruption, self).build(input_shape)
1066
1067
    def call(self, inputs, **kwargs):
1068
1069
        if not self.several_inputs:
1070
            inputs = [inputs]
1071
1072
        if self.bias_field_std > 0:
1073
1074
            # sampling shapes
1075
            batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0]
1076
            std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0)
1077
            bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0)
1078
1079
            # sample small bias field
1080
            bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std))
1081
1082
            # resize bias field and take exponential
1083
            bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field)
1084
            bias_field = tf.math.exp(bias_field)
1085
1086
            # apply bias field with predefined probability
1087
            if self.prob == 1:
1088
                return [tf.math.multiply(bias_field, v) for v in inputs]
1089
            else:
1090
                rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob))
1091
                if self.several_inputs:
1092
                    return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs]
1093
                else:
1094
                    return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0])
1095
1096
        else:
1097
            return inputs
1098
1099
1100
class IntensityAugmentation(Layer):
1101
    """This layer enables to augment the intensities of the input tensor, as well as to apply min_max normalisation.
1102
    The following steps are applied (all are optional):
1103
    1) white noise corruption, with a randomly sampled std dev.
1104
    2) clip the input between two values
1105
    3) min-max normalisation
1106
    4) gamma augmentation (i.e. voxel-wise exponentiation by a randomly sampled power)
1107
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
1108
1109
    :param noise_std: maximum value of the standard deviation of the Gaussian white noise used in 1 (it will be sampled
1110
    from the range [0, noise_std]). Set to 0 to skip this step.
1111
    :param clip: clip the input tensor between the given values. Can either be: a number (in which case we clip between
1112
    0 and the given value), or a list or a numpy array with two elements. Default is 0, where no clipping occurs.
1113
    :param normalise: whether to apply min-max normalisation, to normalise between 0 and 1. Default is True.
1114
    :param norm_perc: percentiles (between 0 and 1) of the sorted intensity values for robust normalisation. Can be:
1115
    a number (in which case the robust minimum is the provided percentile of sorted values, and the maximum is the
1116
    1 - norm_perc percentile), or a list/numpy array of 2 elements (percentiles for the minimum and maximum values).
1117
    The minimum and maximum values are computed separately for each channel if separate_channels is True.
1118
    Default is 0, where we simply take the minimum and maximum values.
1119
    :param gamma_std: standard deviation of the normal distribution from which we sample gamma (in log domain).
1120
    Default is 0, where no gamma augmentation occurs.
1121
    :param contrast_inversion: whether to perform contrast inversion (i.e. 1 - x). If True, this is performed randomly
1122
    for each element of the batch, as well as for each channel.
1123
    :param separate_channels: whether to augment all channels separately. Default is True.
1124
    :param prob_noise: probability to apply noise injection
1125
    :param prob_gamma: probability to apply gamma augmentation
1126
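    example (illustrative sketch only; the input shape and the augmentation parameters are arbitrary):
    image = keras.layers.Input(shape=(160, 160, 160, 1))
    augmented = IntensityAugmentation(noise_std=0.05, clip=300, normalise=True, gamma_std=0.4)(image)
    model = keras.Model(inputs=image, outputs=augmented)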
    """
1127
1128
    def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False,
1129
                 separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs):
1130
1131
        # shape attributes
1132
        self.n_dims = None
1133
        self.n_channels = None
1134
        self.flatten_shape = None
1135
        self.expand_minmax_dim = None
1136
        self.one = None
1137
1138
        # inputs
1139
        self.noise_std = noise_std
1140
        self.clip = clip
1141
        self.clip_values = None
1142
        self.normalise = normalise
1143
        self.norm_perc = norm_perc
1144
        self.perc = None
1145
        self.gamma_std = gamma_std
1146
        self.separate_channels = separate_channels
1147
        self.contrast_inversion = contrast_inversion
1148
        self.prob_noise = prob_noise
1149
        self.prob_gamma = prob_gamma
1150
1151
        super(IntensityAugmentation, self).__init__(**kwargs)
1152
1153
    def get_config(self):
1154
        config = super().get_config()
1155
        config["noise_std"] = self.noise_std
1156
        config["clip"] = self.clip
1157
        config["normalise"] = self.normalise
1158
        config["norm_perc"] = self.norm_perc
1159
        config["gamma_std"] = self.gamma_std
1160
        config["separate_channels"] = self.separate_channels
1161
        config["prob_noise"] = self.prob_noise
1162
        config["prob_gamma"] = self.prob_gamma
1163
        return config
1164
1165
    def build(self, input_shape):
1166
        self.n_dims = len(input_shape) - 2
1167
        self.n_channels = input_shape[-1]
1168
        self.flatten_shape = np.prod(np.array(input_shape[1:-1]))
1169
        self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape
1170
        self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1
1171
        self.one = tf.ones([1], dtype='int32')
1172
        if self.clip:
1173
            self.clip_values = utils.reformat_to_list(self.clip)
1174
            self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]]
1175
        else:
1176
            self.clip_values = None
1177
        if self.norm_perc:
1178
            self.perc = utils.reformat_to_list(self.norm_perc)
1179
            self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]]
1180
        else:
1181
            self.perc = None
1182
1183
        self.built = True
1184
        super(IntensityAugmentation, self).build(input_shape)
1185
1186
    def call(self, inputs, **kwargs):
1187
1188
        # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately)
1189
        batchsize = tf.split(tf.shape(inputs), [1, -1])[0]
1190
        if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion:
1191
            sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0)
1192
            if self.separate_channels:
1193
                sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0)
1194
            else:
1195
                sample_shape = tf.concat([sample_shape, self.one], 0)
1196
        else:
1197
            sample_shape = None
1198
1199
        # add noise with predefined probability
1200
        if self.noise_std > 0:
1201
            noise_stddev = tf.random.uniform(sample_shape, maxval=self.noise_std)
1202
            if self.separate_channels:
1203
                noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev)
1204
            else:
1205
                noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev)
1206
                noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels]))
1207
            if self.prob_noise == 1:
1208
                inputs = inputs + noise
1209
            else:
1210
                inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)),
1211
                                  inputs + noise, inputs)
1212
1213
        # clip images to given values
1214
        if self.clip_values is not None:
1215
            inputs = K.clip(inputs, self.clip_values[0], self.clip_values[1])
1216
1217
        # normalise
1218
        if self.normalise:
1219
            # define robust min and max by sorting values and taking percentile
1220
            if self.perc is not None:
1221
                if self.separate_channels:
1222
                    shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0)
1223
                else:
1224
                    shape = tf.concat([batchsize, self.flatten_shape * self.one], 0)
1225
                intensities = tf.sort(tf.reshape(inputs, shape), axis=1)
1226
                m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...]
1227
                M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...]
1228
            # simple min and max
1229
            else:
1230
                m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1)))
1231
                M = K.max(inputs, axis=list(range(1, self.expand_minmax_dim + 1)))
1232
            # normalise
1233
            m = l2i_et.expand_dims(m, axis=[1] * self.expand_minmax_dim)
1234
            M = l2i_et.expand_dims(M, axis=[1] * self.expand_minmax_dim)
1235
            inputs = tf.clip_by_value(inputs, m, M)
1236
            inputs = (inputs - m) / (M - m + K.epsilon())
1237
1238
        # apply voxel-wise exponentiation with predefined probability
1239
        if self.gamma_std > 0:
1240
            gamma = tf.random.normal(sample_shape, stddev=self.gamma_std)
1241
            if self.prob_gamma == 1:
1242
                inputs = tf.math.pow(inputs, tf.math.exp(gamma))
1243
            else:
1244
                inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)),
1245
                                  tf.math.pow(inputs, tf.math.exp(gamma)), inputs)
1246
1247
        # apply random contrast inversion
1248
        if self.contrast_inversion:
1249
            rand_invert = tf.less(tf.random.uniform(sample_shape, maxval=1), 0.5)
1250
            split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1)
1251
            split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1)
1252
            inverted_channel = list()
1253
            for (channel, invert) in zip(split_channels, split_rand_invert):
1254
                inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype))
1255
            inputs = tf.concat(inverted_channel, -1)
1256
1257
        return inputs
1258
1259
    @staticmethod
1260
    def _single_invert(inputs):
1261
        return K.switch(tf.squeeze(inputs[1]), 1 - inputs[0], inputs[0])
1262
1263
1264
class DiceLoss(Layer):
1265
    """This layer computes the soft Dice loss between two tensors.
1266
    These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels].
1267
    The first input tensor is the GT and the second is the prediction: dice_loss = DiceLoss()([gt, pred])
1268
1269
    :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels.
1270
    Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to
1271
    the inverse of the volume of each label in the ground truth.
1272
    :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures
1273
    when computing the loss. Default is 0 where no boundary weighting is applied.
1274
    :param boundary_dist: (optional) if boundary_weights is not 0, the extra boundary weighting is applied to all voxels
1275
    within this distance to a region boundary. Default is 3.
1276
    :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be
1277
    redundant when we have several labels. This is only used if boundary_weights is not 0.
1278
    :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label
1279
    probabilities sum to 1 at each voxel location). Default is True.
1280
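    example (illustrative sketch only; the one-hot shapes and weighting values are arbitrary):
    gt = keras.layers.Input(shape=(160, 160, 160, 5))    # one-hot ground truth
    pred = keras.layers.Input(shape=(160, 160, 160, 5))  # soft predictions (e.g. softmax output)
    dice_loss = DiceLoss(boundary_weights=100, boundary_dist=3)([gt, pred])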
    """
1281
1282
    def __init__(self,
1283
                 class_weights=None,
1284
                 boundary_weights=0,
1285
                 boundary_dist=3,
1286
                 skip_background=True,
1287
                 enable_checks=True,
1288
                 **kwargs):
1289
1290
        self.class_weights = class_weights
1291
        self.dynamic_weighting = False
1292
        self.class_weights_tens = None
1293
        self.boundary_weights = boundary_weights
1294
        self.boundary_dist = boundary_dist
1295
        self.skip_background = skip_background
1296
        self.enable_checks = enable_checks
1297
        self.spatial_axes = None
1298
        self.avg_pooling_layer = None
1299
        super(DiceLoss, self).__init__(**kwargs)
1300
1301
    def get_config(self):
1302
        config = super().get_config()
1303
        config["class_weights"] = self.class_weights
1304
        config["boundary_weights"] = self.boundary_weights
1305
        config["boundary_dist"] = self.boundary_dist
1306
        config["skip_background"] = self.skip_background
1307
        config["enable_checks"] = self.enable_checks
1308
        return config
1309
1310
    def build(self, input_shape):
1311
1312
        # get shape
1313
        assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.'
1314
        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
1315
        inshape = input_shape[0][1:]
1316
        n_dims = len(inshape[:-1])
1317
        n_labels = inshape[-1]
1318
        self.spatial_axes = list(range(1, n_dims + 1))
1319
        self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims)
1320
        self.skip_background = False if n_labels == 1 else self.skip_background
1321
1322
        # build tensor with class weights
1323
        if self.class_weights is not None:
1324
            if self.class_weights == -1:
1325
                self.dynamic_weighting = True
1326
            else:
1327
                class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels)
1328
                class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32')
1329
                self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0)
1330
1331
        self.built = True
1332
        super(DiceLoss, self).build(input_shape)
1333
1334
    def call(self, inputs, **kwargs):
1335
1336
        # make sure tensors are probabilistic
1337
        gt = inputs[0]
1338
        pred = inputs[1]
1339
        if self.enable_checks:  # disabling is useful to, e.g., use incomplete label maps
1340
            gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1)
1341
            pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1)
1342
1343
        # compute dice loss for each label
1344
        top = 2 * gt * pred
1345
        bottom = tf.math.square(gt) + tf.math.square(pred)
1346
1347
        # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice)
1348
        if self.boundary_weights:
1349
            avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt)
1350
            boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32')
1351
            if self.skip_background:
1352
                boundaries_channels = tf.unstack(boundaries, axis=-1)
1353
                boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1)
1354
            boundary_weights_tensor = 1 + self.boundary_weights * boundaries
1355
            top *= boundary_weights_tensor
1356
            bottom *= boundary_weights_tensor
1357
        else:
1358
            boundary_weights_tensor = None
1359
1360
        # compute loss
1361
        top = tf.math.reduce_sum(top, self.spatial_axes)
1362
        bottom = tf.math.reduce_sum(bottom, self.spatial_axes)
1363
        dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon())
1364
        loss = 1 - dice
1365
1366
        # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels).
1367
        if self.dynamic_weighting:  # the weight of a class is the inverse of its volume in the gt
1368
            if boundary_weights_tensor is not None:  # we account for the boundary weighting to compute volume
1369
                self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes)
1370
            else:
1371
                self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes)
1372
        if self.class_weights_tens is not None:
1373
            self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1)
1374
            loss = tf.reduce_sum(loss * self.class_weights_tens, -1)
1375
1376
        return tf.math.reduce_mean(loss)
1377
1378
    def compute_output_shape(self, input_shape):
1379
        return [[]]
1380
1381
1382
class WeightedL2Loss(Layer):
1383
    """This layer computes a L2 loss weighted by a specified factor (target_value) between two tensors.
1384
    This is designed to be used on the layer before the softmax.
1385
    The tensors are expected to have the same shape [batchsize, size_dim1, ..., size_dimN, n_labels].
1386
    The first input tensor is the GT and the second is the prediction: wl2_loss = WeightedL2Loss()([gt, pred])
1387
1388
    :param target_value: target value for the layer before softmax: target_value when gt = 1, -target_value when gt = 0.
1389
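    example (illustrative sketch only; the shapes are arbitrary):
    gt = keras.layers.Input(shape=(160, 160, 160, 5))           # one-hot ground truth
    pre_softmax = keras.layers.Input(shape=(160, 160, 160, 5))  # prediction before the final softmax
    wl2_loss = WeightedL2Loss(target_value=5)([gt, pre_softmax])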
    """
1390
1391
    def __init__(self, target_value=5, **kwargs):
1392
        self.target_value = target_value
1393
        self.n_labels = None
1394
        super(WeightedL2Loss, self).__init__(**kwargs)
1395
1396
    def get_config(self):
1397
        config = super().get_config()
1398
        config["target_value"] = self.target_value
1399
        return config
1400
1401
    def build(self, input_shape):
1402
        assert len(input_shape) == 2, 'WeightedL2Loss expects 2 inputs to compute the loss.'
1403
        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
1404
        self.n_labels = input_shape[0][-1]
1405
        self.built = True
1406
        super(WeightedL2Loss, self).build(input_shape)
1407
1408
    def call(self, inputs, **kwargs):
1409
        gt = inputs[0]
1410
        pred = inputs[1]
1411
        weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1)
1412
        return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels)
1413
1414
    def compute_output_shape(self, input_shape):
1415
        return [[]]
1416
1417
1418
class CrossEntropyLoss(Layer):
1419
    """This layer computes the cross-entropy loss between two tensors.
1420
    These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels].
1421
    The first input tensor is the GT and the second is the prediction: ce_loss = CrossEntropyLoss()([gt, pred])
1422
1423
    :param class_weights: (optional) if given, the loss is a weighted average of the cross-entropy across labels.
1424
    Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to
1425
    the inverse of the volume of each label in the ground truth.
1426
    :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures
1427
    when computing the loss. Default is 0 where no boundary weighting is applied.
1428
    :param boundary_dist: (optional) if boundary_weights is not 0, the extra boundary weighting is applied to all voxels
1429
    within this distance to a region boundary. Default is 3.
1430
    :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be
1431
    redundant when we have several labels. This is only used if boundary_weights is not 0.
1432
    :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label
1433
    probabilities sum to 1 at each voxel location). Default is True.
1434
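    example (illustrative sketch only; the shapes and boundary weighting are arbitrary):
    gt = keras.layers.Input(shape=(160, 160, 160, 5))
    pred = keras.layers.Input(shape=(160, 160, 160, 5))
    ce_loss = CrossEntropyLoss(boundary_weights=10)([gt, pred])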
    """
1435
1436
    def __init__(self,
1437
                 class_weights=None,
1438
                 boundary_weights=0,
1439
                 boundary_dist=3,
1440
                 skip_background=True,
1441
                 enable_checks=True,
1442
                 **kwargs):
1443
1444
        self.class_weights = class_weights
1445
        self.dynamic_weighting = False
1446
        self.class_weights_tens = None
1447
        self.boundary_weights = boundary_weights
1448
        self.boundary_dist = boundary_dist
1449
        self.skip_background = skip_background
1450
        self.enable_checks = enable_checks
1451
        self.spatial_axes = None
1452
        self.avg_pooling_layer = None
1453
        super(CrossEntropyLoss, self).__init__(**kwargs)
1454
1455
    def get_config(self):
1456
        config = super().get_config()
1457
        config["class_weights"] = self.class_weights
1458
        config["boundary_weights"] = self.boundary_weights
1459
        config["boundary_dist"] = self.boundary_dist
1460
        config["skip_background"] = self.skip_background
1461
        config["enable_checks"] = self.enable_checks
1462
        return config
1463
1464
    def build(self, input_shape):
1465
1466
        # get shape
1467
        assert len(input_shape) == 2, 'CrossEntropyLoss expects 2 inputs to compute the cross-entropy loss.'
1468
        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
1469
        inshape = input_shape[0][1:]
1470
        n_dims = len(inshape[:-1])
1471
        n_labels = inshape[-1]
1472
        self.spatial_axes = list(range(1, n_dims + 1))
1473
        self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims)
1474
        self.skip_background = False if n_labels == 1 else self.skip_background
1475
1476
        # build tensor with class weights
1477
        if self.class_weights is not None:
1478
            if self.class_weights == -1:
1479
                self.dynamic_weighting = True
1480
            else:
1481
                class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels)
1482
                class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32')
1483
                self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims))
1484
1485
        self.built = True
1486
        super(CrossEntropyLoss, self).build(input_shape)
1487
1488
    def call(self, inputs, **kwargs):
1489
1490
        # make sure tensors are probabilistic
1491
        gt = inputs[0]
1492
        pred = inputs[1]
1493
        if self.enable_checks:  # disabling is useful to, e.g., use incomplete label maps
1494
            gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1)
1495
            pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
1496
            pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())  # to avoid log(0)
1497
1498
        # compare prediction/target, ce has the same shape as the input tensors
1499
        ce = -gt * tf.math.log(pred)
1500
1501
        # apply boundary weighting (ie voxels close to region boundaries will be counted several times in the loss)
1502
        if self.boundary_weights:
1503
            avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt)
1504
            boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32')
1505
            if self.skip_background:
1506
                boundaries_channels = tf.unstack(boundaries, axis=-1)
1507
                boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1)
1508
            boundary_weights_tensor = 1 + self.boundary_weights * boundaries
1509
            ce *= boundary_weights_tensor
1510
        else:
1511
            boundary_weights_tensor = None
1512
1513
        # apply class weighting across labels. By the end of this, ce still has the same shape as the input tensors.
1514
        if self.dynamic_weighting:  # the weight of a class is the inverse of its volume in the gt
1515
            if boundary_weights_tensor is not None:  # we account for the boundary weighting to compute volume
1516
                self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True)
1517
            else:
1518
                self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes)
1519
        if self.class_weights_tens is not None:
1520
            self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1)
1521
            ce = tf.reduce_sum(ce * self.class_weights_tens, -1)
1522
1523
        # sum along label axis, and take the mean along spatial dimensions
1524
        ce = tf.math.reduce_mean(tf.math.reduce_sum(ce, axis=-1))
1525
1526
        return ce
1527
1528
    def compute_output_shape(self, input_shape):
1529
        return [[]]
1530
1531
1532
class MomentLoss(Layer):
1533
    """This layer computes a moment loss between two tensors. Specifically, it computes the distance between the centres
1534
    of gravity for all the channels of the two tensors, and then returns a value averaged across all channels.
1535
    These tensors are expected to have the same shape [batch, size_dim1, ..., size_dimN, n_channels].
1536
    The first input tensor is the GT and the second is the prediction: moment_loss = MomentLoss()([gt, pred])
1537
1538
    :param class_weights: (optional) if given, the loss is a weighted average of the distances across channels.
1539
    Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to
1540
    the inverse of the volume of each label in the ground truth.
1541
    :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label
1542
    probabilities sum to 1 at each voxel location). Default is True.
1543
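    example (illustrative sketch only; the shapes are arbitrary):
    gt = keras.layers.Input(shape=(96, 96, 96, 4))
    pred = keras.layers.Input(shape=(96, 96, 96, 4))
    moment_loss = MomentLoss()([gt, pred])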
    """
1544
1545
    def __init__(self, class_weights=None, enable_checks=False, **kwargs):
1546
        self.class_weights = class_weights
1547
        self.dynamic_weighting = False
1548
        self.class_weights_tens = None
1549
        self.enable_checks = enable_checks
1550
        self.spatial_axes = None
1551
        self.coordinates = None
1552
        super(MomentLoss, self).__init__(**kwargs)
1553
1554
    def get_config(self):
1555
        config = super().get_config()
1556
        config["class_weights"] = self.class_weights
1557
        config["enable_checks"] = self.enable_checks
1558
        return config
1559
1560
    def build(self, input_shape):
1561
1562
        # get shape
1563
        assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the moment loss.'
1564
        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
1565
        inshape = input_shape[0][1:]
1566
        n_dims = len(inshape[:-1])
1567
        n_labels = inshape[-1]
1568
        self.spatial_axes = list(range(1, n_dims + 1))
1569
1570
        # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan)
1571
        self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1)
1572
        self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32')
1573
1574
        # build tensor with class weights
1575
        if self.class_weights is not None:
1576
            if self.class_weights == -1:
1577
                self.dynamic_weighting = True
1578
            else:
1579
                class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels)
1580
                class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32')
1581
                self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0)
1582
1583
        self.built = True
1584
        super(MomentLoss, self).build(input_shape)
1585
1586
    def call(self, inputs, **kwargs):
1587
1588
        # make sure tensors are probabilistic
1589
        gt = inputs[0]  # (B, dim1, dim2, ..., dimN, nchan)
1590
        pred = inputs[1]
1591
        if self.enable_checks:  # disabling is useful to, e.g., use incomplete label maps
1592
            gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
1593
            pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
1594
1595
        # compute loss
1596
        gt_mean_coordinates = self._mean_coordinates(gt)  # (B, ndim, nchan)
1597
        pred_mean_coordinates = self._mean_coordinates(pred)
1598
        loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1))  # (B, nchan)
1599
1600
        # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels).
1601
        if self.dynamic_weighting:  # the weight of a class is the inverse of its volume in the gt
1602
            self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes)
1603
        if self.class_weights_tens is not None:
1604
            self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1)
1605
            loss = tf.reduce_sum(loss * self.class_weights_tens, -1)
1606
1607
        return tf.math.reduce_mean(loss)
1608
1609
    def _mean_coordinates(self, tensor):
1610
        tensor = l2i_et.expand_dims(tensor, axis=-2)  # (B, dim1, dim2, ..., dimN, 1, nchan)
1611
        numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes)  # (B, ndim, nchan)
1612
        denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon()
1613
        return numerator / denominator
1614
1615
    def compute_output_shape(self, input_shape):
1616
        return [[]]
1617
1618
1619
class ResetValuesToZero(Layer):
1620
    """This layer enables to reset given values to 0 within the input tensors.
1621
1622
    :param values: list of values to be reset to 0.
1623
1624
    example:
1625
    input = tf.convert_to_tensor(np.array([[1, 0, 2, 2, 2, 2, 0],
1626
                                           [1, 3, 3, 3, 3, 3, 3],
1627
                                           [1, 0, 0, 0, 4, 4, 4]]))
1628
    values = [1, 3]
1629
    ResetValuesToZero(values)(input)
1630
    >> [[0, 0, 2, 2, 2, 2, 0],
1631
        [0, 0, 0, 0, 0, 0, 0],
1632
        [0, 0, 0, 0, 4, 4, 4]]
1633
    """
1634
1635
    def __init__(self, values, **kwargs):
1636
        assert values is not None, 'please provide correct list of values, received None'
1637
        self.values = utils.reformat_to_list(values)
1638
        self.values_tens = None
1639
        self.n_values = len(values)
1640
        super(ResetValuesToZero, self).__init__(**kwargs)
1641
1642
    def get_config(self):
1643
        config = super().get_config()
1644
        config["values"] = self.values
1645
        return config
1646
1647
    def build(self, input_shape):
1648
        self.values_tens = tf.convert_to_tensor(self.values)
1649
        self.built = True
1650
        super(ResetValuesToZero, self).build(input_shape)
1651
1652
    def call(self, inputs, **kwargs):
1653
        values = tf.cast(self.values_tens, dtype=inputs.dtype)
1654
        for i in range(self.n_values):
1655
            inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs)
1656
        return inputs
1657
1658
1659
class ConvertLabels(Layer):
1660
    """Convert all labels in a tensor by the corresponding given set of values.
1661
    labels_converted = ConvertLabels(source_values, dest_values)(labels).
1662
    labels must be an int32 tensor, and labels_converted will also be int32.
1663
1664
    :param source_values: list of all the possible values in labels. Must be a list or a 1D numpy array.
1665
    :param dest_values: list of all the target label values. Must be ordered the same as source values:
1666
    labels[labels == source_values[i]] = dest_values[i].
1667
    If None (default), dest_values is equal to [0, ..., N-1], where N is the total number of values in source_values,
1668
    which remaps label maps to [0, ..., N-1].
1669
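    example (illustrative):
    labels = tf.convert_to_tensor(np.array([2, 4, 2, 0, 4]), dtype='int32')
    ConvertLabels(source_values=[0, 2, 4])(labels)
    >> [1, 2, 1, 0, 2]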
    """
1670
1671
    def __init__(self, source_values, dest_values=None, **kwargs):
1672
        self.source_values = source_values
1673
        self.dest_values = dest_values
1674
        self.lut = None
1675
        super(ConvertLabels, self).__init__(**kwargs)
1676
1677
    def get_config(self):
1678
        config = super().get_config()
1679
        config["source_values"] = self.source_values
1680
        config["dest_values"] = self.dest_values
1681
        return config
1682
1683
    def build(self, input_shape):
1684
        self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32')
1685
        self.built = True
1686
        super(ConvertLabels, self).build(input_shape)
1687
1688
    def call(self, inputs, **kwargs):
1689
        return tf.gather(self.lut, tf.cast(inputs, dtype='int32'))
1690
1691
1692
class PadAroundCentre(Layer):
1693
    """Pad the input tensor to the specified shape with the given value.
1694
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
1695
    :param pad_margin: margin to use for padding. The tensor will be padded by the provided margin on each side.
1696
    Can either be a number (all axes padded with the same margin), or a list/numpy array of length n_dims.
1697
    example: if tensor is of shape [batch, x, y, z, n_channels] and margin=10, then the padded tensor will be of
1698
    shape [batch, x+2*10, y+2*10, z+2*10, n_channels].
1699
    :param pad_shape: shape to pad the tensor to. Can either be a number (all axes padded to the same shape), or a
1700
    list/numpy array of length n_dims.
1701
    :param value: value to pad the tensors with. Default is 0.
1702
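    example (illustrative sketch only; the shapes are arbitrary):
    image = keras.layers.Input(shape=(150, 150, 150, 1))
    padded = PadAroundCentre(pad_shape=160)(image)  # tensor padded to shape [batch, 160, 160, 160, 1]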
    """
1703
1704
    def __init__(self, pad_margin=None, pad_shape=None, value=0, **kwargs):
1705
        self.pad_margin = pad_margin
1706
        self.pad_shape = pad_shape
1707
        self.value = value
1708
        self.pad_margin_tens = None
1709
        self.pad_shape_tens = None
1710
        self.n_dims = None
1711
        super(PadAroundCentre, self).__init__(**kwargs)
1712
1713
    def get_config(self):
1714
        config = super().get_config()
1715
        config["pad_margin"] = self.pad_margin
1716
        config["pad_shape"] = self.pad_shape
1717
        config["value"] = self.value
1718
        return config
1719
1720
    def build(self, input_shape):
1721
        # input shape
1722
        self.n_dims = len(input_shape) - 2
1723
        shape = list(input_shape)
1724
        shape[0] = 0
1725
        shape[-1] = 0
1726
1727
        if self.pad_margin is not None:
1728
            assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.'
1729
1730
            # reformat padding margins
1731
            pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2))
1732
            self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32')
1733
1734
        elif self.pad_shape is not None:
1735
            assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.'
1736
1737
            # pad shape
1738
            tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32')
1739
            self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0])
1740
            self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32')
1741
            self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens)
1742
1743
            # padding margin
1744
            min_margins = (self.pad_shape_tens - tensor_shape) // 2  # floor division keeps the paddings int32 for tf.pad
1745
            max_margins = self.pad_shape_tens - tensor_shape - min_margins
1746
            self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1)
1747
1748
        else:
1749
            raise Exception('please either provide a padding shape or a padding margin.')
1750
1751
        self.built = True
1752
        super(PadAroundCentre, self).build(input_shape)
1753
1754
    def call(self, inputs, **kwargs):
1755
        return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value)
1756
1757
1758
class MaskEdges(Layer):
1759
    """Reset the edges of a tensor to zero (i.e. with bands of zeros along the specified axes).
1760
    The width of the zero-band is randomly drawn from a uniform distribution, whose range is given in boundaries.
1761
1762
    :param axes: axes along which to reset edges to zero. Can be an int (single axis), or a sequence.
1763
    :param boundaries: numpy array of shape (len(axes), 4). Each row contains the two bounds of the uniform
1764
    distributions from which we draw the width of the zero-bands on each side.
1765
    Those bounds must be expressed in relative size (i.e. as fractions of the tensor shape, between 0 and 1).
1766
    :return: a tensor of the same shape as the input, with bands of zeros along the specified axes.
1767
1768
    example:
1769
    tensor=tf.constant([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1770
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1771
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1772
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1773
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1774
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1775
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1776
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1777
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1778
                       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])  # shape = [1,10,10,1]
1779
    axes=1
1780
    boundaries = np.array([[0.2, 0.45, 0.85, 0.9]])
1781
1782
    In this case, we reset the edges along the 2nd dimension (i.e. the 1st dimension after the batch dimension),
1783
    the 1st zero-band will expand from the 1st row to a number drawn from [0.2*tensor.shape[1], 0.45*tensor.shape[1]],
1784
    and the 2nd zero-band will expand from a row drawn from [0.85*tensor.shape[1], 0.9*tensor.shape[1]], to the end of
1785
    the tensor. A possible output could be:
1786
    array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
1787
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
1788
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
1789
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1790
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1791
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1792
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1793
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1794
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
1795
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])  # shape = [1,10,10,1]
1796
    """
1797
1798
    def __init__(self, axes, boundaries, prob_mask=1, **kwargs):
1799
        self.axes = utils.reformat_to_list(axes, dtype='int')
1800
        self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes))
1801
        self.prob_mask = prob_mask
1802
        self.inputshape = None
1803
        super(MaskEdges, self).__init__(**kwargs)
1804
1805
    def get_config(self):
1806
        config = super().get_config()
1807
        config["axes"] = self.axes
1808
        config["boundaries"] = self.boundaries
1809
        config["prob_mask"] = self.prob_mask
1810
        return config
1811
1812
    def build(self, input_shape):
1813
        self.inputshape = input_shape
1814
        self.built = True
1815
        super(MaskEdges, self).build(input_shape)
1816
1817
    def call(self, inputs, **kwargs):
1818
1819
        # build mask
1820
        mask = tf.ones_like(inputs)
1821
        for i, axis in enumerate(self.axes):
1822
1823
            # select restricting indices
1824
            axis_boundaries = self.boundaries[i, :]
1825
            idx1 = tf.math.round(tf.random.uniform([1],
1826
                                                   minval=axis_boundaries[0] * self.inputshape[axis],
1827
                                                   maxval=axis_boundaries[1] * self.inputshape[axis]))
1828
            idx2 = tf.math.round(tf.random.uniform([1],
1829
                                                   minval=axis_boundaries[2] * self.inputshape[axis],
1830
                                                   maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1)
1831
            idx3 = self.inputshape[axis] - idx1 - idx2
1832
            split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32')
1833
1834
            # update mask
1835
            split_list = tf.split(inputs, split_idx, axis=axis)
1836
            tmp_mask = tf.concat([tf.zeros_like(split_list[0]),
1837
                                  tf.ones_like(split_list[1]),
1838
                                  tf.zeros_like(split_list[2])], axis=axis)
1839
            mask = mask * tmp_mask
1840
1841
        # apply the mask with the predefined probability
1842
        tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)),
1843
                          inputs * mask,
1844
                          inputs)
1845
1846
        return [tensor, mask]
1847
1848
    def compute_output_shape(self, input_shape):
1849
        return [input_shape] * 2
1850
1851
1852
class ImageGradients(Layer):
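    """This layer computes the spatial gradients of the input tensor along each spatial dimension.
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].

    :param gradient_type: method used to compute the gradients. Can be 'sobel' (separable Sobel kernels applied by
    convolution) or '1-step_diff' (one-step finite differences, only supported for 2D and 3D inputs).
    :param return_magnitude: whether to return the total gradient magnitude instead of concatenating the gradients
    computed along each direction in the channel dimension.
    """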
1853
1854
    def __init__(self, gradient_type='sobel', return_magnitude=False, **kwargs):
1855
1856
        self.gradient_type = gradient_type
1857
        assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \
1858
            'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type
1859
1860
        # shape
1861
        self.n_dims = 0
1862
        self.shape = None
1863
        self.n_channels = 0
1864
1865
        # convolution params if sobel diff
1866
        self.stride = None
1867
        self.kernels = None
1868
        self.convnd = None
1869
1870
        self.return_magnitude = return_magnitude
1871
1872
        super(ImageGradients, self).__init__(**kwargs)
1873
1874
    def get_config(self):
1875
        config = super().get_config()
1876
        config["gradient_type"] = self.gradient_type
1877
        config["return_magnitude"] = self.return_magnitude
1878
        return config
1879
1880
    def build(self, input_shape):
1881
1882
        # get shapes
1883
        self.n_dims = len(input_shape) - 2
1884
        self.shape = input_shape[1:]
1885
        self.n_channels = input_shape[-1]
1886
1887
        # prepare kernel if sobel gradients
1888
        if self.gradient_type == 'sobel':
1889
            self.kernels = l2i_et.sobel_kernels(self.n_dims)
1890
            self.stride = [1] * (self.n_dims + 2)
1891
            self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
1892
        else:
1893
            self.kernels = self.convnd = self.stride = None
1894
1895
        self.built = True
1896
        super(ImageGradients, self).build(input_shape)
1897
1898
    def call(self, inputs, **kwargs):
1899
1900
        image = inputs
1901
        batchsize = tf.split(tf.shape(inputs), [1, -1])[0]
1902
        gradients = list()
1903
1904
        # sobel method
1905
        if self.gradient_type == 'sobel':
1906
            # get sobel gradients in each direction
1907
            for n in range(self.n_dims):
1908
                gradient = image
1909
                # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel
1910
                for k in self.kernels[n]:
1911
                    gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., c], -1), k, self.stride, 'SAME')
1912
                                          for c in range(self.n_channels)], -1)
1913
                gradients.append(gradient)
1914
1915
        # 1-step method, only supports 2 and 3D
1916
        else:
1917
1918
            # get 1-step diff
1919
            if self.n_dims == 2:
1920
                gradients.append(image[:, 1:, :, :] - image[:, :-1, :, :])  # dx
1921
                gradients.append(image[:, :, 1:, :] - image[:, :, :-1, :])  # dy
1922
1923
            elif self.n_dims == 3:
1924
                gradients.append(image[:, 1:, :, :, :] - image[:, :-1, :, :, :])  # dx
1925
                gradients.append(image[:, :, 1:, :, :] - image[:, :, :-1, :, :])  # dy
1926
                gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :])  # dz
1927
1928
            else:
1929
                raise Exception('ImageGradients only supports 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims)
1930
1931
            # pad with zeros to return tensors of the same shape as input
1932
            for i in range(self.n_dims):
1933
                tmp_shape = list(self.shape)
1934
                tmp_shape[i] = 1
1935
                zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype)
1936
                gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1)
1937
1938
        # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis
1939
        if self.return_magnitude:
1940
            gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1))
1941
        else:
1942
            gradients = tf.concat(gradients, axis=-1)
1943
1944
        return gradients
1945
1946
    def compute_output_shape(self, input_shape):
1947
        if not self.return_magnitude:
1948
            input_shape = list(input_shape)
1949
            input_shape[-1] = self.n_dims
1950
        return tuple(input_shape)
1951
1952
1953
class RandomDilationErosion(Layer):
1954
    """
1955
    GPU implementation of binary dilation or erosion. The operation can be chosen to be always a dilation, or always an
1956
    erosion, or randomly chosen between the two for each element of the batch.
1957
    The chosen operation is applied to the input with a given probability. Moreover, it is also possible to randomise
1958
    the factor of the operation for each element of the mini-batch.
1959
    :param min_factor: minimum possible value for the dilation/erosion factor. Must be an integer.
1960
    :param max_factor: maximum possible value for the dilation/erosion factor. Must be an integer.
1961
    Set it to the same value as min_factor to always perform dilation/erosion with the same factor.
1962
    :param prob: probability with which to apply the selected operation to the input.
1963
    :param operation: which operation to apply. Can be 'dilation' or 'erosion' or 'random'.
1964
    :param return_mask: if operation is erosion and the input of this layer is a label map, we have the
1965
    choice of returning either the eroded label map (return_mask=False) or the binary mask itself (return_mask=True).
1966
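    example (illustrative sketch only; the label-map shape and the factors are arbitrary):
    labels = keras.layers.Input(shape=(160, 160, 160, 1))
    augmented = RandomDilationErosion(min_factor=1, max_factor=5, operation='random')(labels)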
    """
1967
1968
    def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False,
1969
                 **kwargs):
1970
1971
        self.min_factor = min_factor
1972
        self.max_factor = max_factor
1973
        self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor
1974
        self.prob = prob
1975
        self.operation = operation
1976
        self.return_mask = return_mask
1977
        self.n_dims = None
1978
        self.inshape = None
1979
        self.n_channels = None
1980
        self.convnd = None
1981
        super(RandomDilationErosion, self).__init__(**kwargs)
1982
1983
    def get_config(self):
1984
        config = super().get_config()
1985
        config["min_factor"] = self.min_factor
1986
        config["max_factor"] = self.max_factor
1987
        config["max_factor_dilate"] = self.max_factor_dilate
1988
        config["prob"] = self.prob
1989
        config["operation"] = self.operation
1990
        config["return_mask"] = self.return_mask
1991
        return config
1992
1993
    def build(self, input_shape):
1994
1995
        # input shape
1996
        self.inshape = input_shape
1997
        self.n_dims = len(self.inshape) - 2
1998
        self.n_channels = self.inshape[-1]
1999
2000
        # prepare convolution
2001
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
2002
2003
        self.built = True
2004
        super(RandomDilationErosion, self).build(input_shape)
2005
2006
    def call(self, inputs, **kwargs):
2007
2008
        # sample the probability of applying the operation; with 'random', negative values mean erosion, positive dilation
2009
        batchsize = tf.split(tf.shape(inputs), [1, -1])[0]
2010
        shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0)
2011
        if self.operation == 'dilation':
2012
            prob = tf.random.uniform(shape, 0, 1)
2013
        elif self.operation == 'erosion':
2014
            prob = tf.random.uniform(shape, -1, 0)
2015
        elif self.operation == 'random':
2016
            prob = tf.random.uniform(shape, -1, 1)
2017
        else:
2018
            raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation)
2019
2020
        # build kernel
2021
        if self.min_factor == self.max_factor:
2022
            dist_threshold = self.min_factor * tf.ones(shape, dtype='int32')
2023
        else:
2024
            if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'):
2025
                dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32')
2026
            else:
2027
                dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32')
2028
        kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor)
2029
2030
        # convolve input mask with kernel according to given probability
2031
        mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32')
2032
        mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32)
2033
        mask = tf.cast(mask, 'bool')
2034
2035
        if self.return_mask:
2036
            return mask
2037
        else:
2038
            return inputs * tf.cast(mask, dtype=inputs.dtype)
2039
2040
    def _sample_factor(self, inputs):
2041
        return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0),
2042
                                tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'),
2043
                                tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')),
2044
                       dtype='float32')
2045
2046
    def _single_blur(self, inputs):
2047
        # dilate...
2048
        new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001),
2049
                            tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1],
2050
                                    [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'),
2051
                            inputs[0])
2052
        # ...or erode
2053
        new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)),
2054
                            1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1],
2055
                                        [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'),
2056
                            new_mask)
2057
        return new_mask
2058
2059
    def compute_output_shape(self, input_shape):
2060
        return input_shape