|
|
|
"""
This file regroups several custom keras layers used in the generation model:
- RandomSpatialDeformation,
- RandomCrop,
- RandomFlip,
- SampleConditionalGMM,
- SampleResolution,
- GaussianBlur,
- DynamicGaussianBlur,
- MimicAcquisition,
- BiasFieldCorruption,
- IntensityAugmentation,
- DiceLoss,
- WeightedL2Loss,
- ResetValuesToZero,
- ConvertLabels,
- PadAroundCentre,
- MaskEdges,
- ImageGradients,
- RandomDilationErosion


If you use this code, please cite the first SynthSeg paper:
https://github.com/BBillot/lab2im/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""
|
|


# python imports
import keras
import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.layers import Layer

# project imports
from ext.lab2im import utils
from ext.lab2im import edit_tensors as l2i_et

# third-party imports
from ext.neuron import utils as nrn_utils
import ext.neuron.layers as nrn_layers

|
|
class RandomSpatialDeformation(Layer):
    """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations.
    The input tensors are expected to have the same shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    The non-linear deformation is obtained by:
    1) a small-size SVF is sampled from a centred normal distribution of random standard deviation.
    2) it is resized with trilinear interpolation to half the shape of the input tensor
    3) it is integrated to obtain a diffeomorphic transformation
    4) finally, it is resized (again with trilinear interpolation) to full image size
    :param scaling_bounds: (optional) range of the random scaling to apply. The scaling factor for each dimension is
    sampled from a uniform distribution of predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    [1-scaling_bounds, 1+scaling_bounds] for each dimension.
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
    of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    4) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.15 (case 1).
    :param rotation_bounds: (optional) same as scaling_bounds but for the rotation angle, except that for cases 1
    and 2, the bounds are centred on 0 rather than 1, i.e. [0-rotation_bounds[i], 0+rotation_bounds[i]].
    Default is rotation_bounds = 10.
    :param shearing_bounds: (optional) same as scaling_bounds. Default is shearing_bounds = 0.02.
    :param translation_bounds: (optional) same as scaling_bounds. Default is translation_bounds = False, but we
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
    :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}.
    This is done regardless of the value of rotation_bounds. If True, a different value is sampled for each dimension.
    :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we
    sample the small-size SVF. Set to 0 if you wish to completely turn the elastic deformation off.
    :param nonlin_scale: (optional) if nonlin_std is strictly positive, ratio between the shape of the input tensor
    and the shape of the sampled SVF.
    :param inter_method: (optional) interpolation method when deforming the input tensor. Can be 'linear' or 'nearest'.
    :param prob_deform: (optional) probability to apply the spatial deformation.
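
    example (an illustrative sketch; `labels` and `image` are hypothetical tensors of shape
    [batchsize, 160, 160, 160, 1] that are not defined in this file):
    labels, image = RandomSpatialDeformation(nonlin_std=3., inter_method=['nearest', 'linear'])([labels, image])
    applies the same randomly sampled affine and elastic transformation to both tensors, using nearest-neighbour
    interpolation for the label map and linear interpolation for the image.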
|
|
    """

    def __init__(self,
                 scaling_bounds=0.15,
                 rotation_bounds=10,
                 shearing_bounds=0.02,
                 translation_bounds=False,
                 enable_90_rotations=False,
                 nonlin_std=4.,
                 nonlin_scale=.0625,
                 inter_method='linear',
                 prob_deform=1,
                 **kwargs):

        # shape attributes
        self.n_inputs = 1
        self.inshape = None
        self.n_dims = None
        self.small_shape = None

        # deformation attributes
        self.scaling_bounds = scaling_bounds
        self.rotation_bounds = rotation_bounds
        self.shearing_bounds = shearing_bounds
        self.translation_bounds = translation_bounds
        self.enable_90_rotations = enable_90_rotations
        self.nonlin_std = nonlin_std
        self.nonlin_scale = nonlin_scale

        # boolean attributes
        self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \
                                  (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \
                                  self.enable_90_rotations
        self.apply_elastic_trans = self.nonlin_std > 0
        self.prob_deform = prob_deform

        # interpolation methods
        self.inter_method = inter_method

        super(RandomSpatialDeformation, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["scaling_bounds"] = self.scaling_bounds
        config["rotation_bounds"] = self.rotation_bounds
        config["shearing_bounds"] = self.shearing_bounds
        config["translation_bounds"] = self.translation_bounds
        config["enable_90_rotations"] = self.enable_90_rotations
        config["nonlin_std"] = self.nonlin_std
        config["nonlin_scale"] = self.nonlin_scale
        config["inter_method"] = self.inter_method
        config["prob_deform"] = self.prob_deform
        return config

    def build(self, input_shape):

        if not isinstance(input_shape, list):
            inputshape = [input_shape]
        else:
            self.n_inputs = len(input_shape)
            inputshape = input_shape
        self.inshape = inputshape[0][1:]
        self.n_dims = len(self.inshape) - 1

        if self.apply_elastic_trans:
            self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims],
                                                        self.nonlin_scale, self.n_dims)
        else:
            self.small_shape = None

        self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str')

        self.built = True
        super(RandomSpatialDeformation, self).build(input_shape)

    def call(self, inputs, **kwargs):

        # reformat inputs and get their shape
        if self.n_inputs < 2:
            inputs = [inputs]
        types = [v.dtype for v in inputs]
        inputs = [tf.cast(v, dtype='float32') for v in inputs]
        batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0]

        # initialise list of transforms to apply
        list_trans = list()

        # add affine transform to the list of transforms
        if self.apply_affine_trans:
            affine_trans = utils.sample_affine_transform(batchsize,
                                                         self.n_dims,
                                                         self.rotation_bounds,
                                                         self.scaling_bounds,
                                                         self.shearing_bounds,
                                                         self.translation_bounds,
                                                         self.enable_90_rotations)
            list_trans.append(affine_trans)

        # prepare non-linear deformation field and add it to the list of transforms
        if self.apply_elastic_trans:

            # sample small field from normal distribution of specified std dev
            trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0)
            trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std)
            elastic_trans = tf.random.normal(trans_shape, stddev=trans_std)

            # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size
            resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)]
            elastic_trans = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans)
            elastic_trans = nrn_layers.VecInt()(elastic_trans)
            elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans)
            list_trans.append(elastic_trans)

        # apply deformations and return tensors with correct dtype
        if self.apply_affine_trans | self.apply_elastic_trans:
            if self.prob_deform == 1:
                inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in
                          zip(self.inter_method, inputs)]
            else:
                rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform))
                inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v)
                          for (m, v) in zip(self.inter_method, inputs)]
        if self.n_inputs < 2:
            return tf.cast(inputs[0], types[0])
        else:
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]


|
|
class RandomCrop(Layer):
    """Randomly crop all input tensors to a given shape. This cropping is applied to all channels.
    The input tensors are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    :param crop_shape: list with cropping shape in each dimension (excluding batch and channel dimension)

    example:
    if input is a tensor of shape [batchsize, 160, 160, 160, 3],
    output = RandomCrop(crop_shape=[96, 128, 96])(input)
    will yield an output of shape [batchsize, 96, 128, 96, 3] that is obtained by cropping with randomly selected
    cropping indices.
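
    example with several inputs (an illustrative sketch; `labels` and `image` are hypothetical tensors of shape
    [batchsize, 160, 160, 160, 1] that are not defined in this file):
    cropped_labels, cropped_image = RandomCrop(crop_shape=[96, 128, 96])([labels, image])
    both outputs have shape [batchsize, 96, 128, 96, 1] and are cropped with the same randomly selected indices,
    so they remain spatially aligned.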
|
|
    """

    def __init__(self, crop_shape, **kwargs):

        self.several_inputs = True
        self.crop_max_val = None
        self.crop_shape = crop_shape
        self.n_dims = len(crop_shape)
        self.list_n_channels = None
        super(RandomCrop, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["crop_shape"] = self.crop_shape
        return config

    def build(self, input_shape):

        if not isinstance(input_shape, list):
            self.several_inputs = False
            inputshape = [input_shape]
        else:
            inputshape = input_shape
        self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape)
        self.list_n_channels = [i[-1] for i in inputshape]
        self.built = True
        super(RandomCrop, self).build(input_shape)

    def call(self, inputs, **kwargs):

        # if only one input is provided, perform the cropping directly
        if not self.several_inputs:
            return tf.map_fn(self._single_slice, inputs, dtype=inputs.dtype)

        # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location
        else:
            types = [v.dtype for v in inputs]
            inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1)
            inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32)
            inputs = tf.split(inputs, self.list_n_channels, axis=-1)
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]

    def _single_slice(self, vol):
        crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32')
        crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0)
        crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32')
        return tf.slice(vol, begin=crop_idx, size=crop_size)

    def compute_output_shape(self, input_shape):
        output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels]
        return output_shape if self.several_inputs else output_shape[0]


|
|
class RandomFlip(Layer):
    """This layer randomly flips the input tensor along the specified axes with a specified probability.
    It can also take multiple tensors as inputs (if they have the same shape). The same flips will be applied to all
    input tensors. These are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    If specified, this layer can also swap corresponding values. This is especially useful when flipping label maps
    with different labels for right/left structures, such that the flipped label maps keep a consistent labelling.
    :param axis: integer, or list of integers specifying the dimensions along which to flip.
    If a list, the input tensors can be flipped simultaneously in several directions. The values in axis exclude
    the batch dimension (e.g. 0 will flip the tensor along the first axis after the batch dimension).
    Default is None, where the tensors can be flipped along all axes (except batch and channel axes).
    :param swap_labels: boolean to specify whether to swap the values of each input. Values are only swapped if an odd
    number of flips is applied.
    Can also be a list if several tensors are given as input.
    All the inputs for which the values need to be swapped must be int32 or int64.
    :param label_list: if swap_labels is True, list of all labels contained in the input label maps. Must be ordered
    as follows: first the neutral labels (i.e. non-sided), then left labels and right labels.
    :param n_neutral_labels: if swap_labels is True, number of non-sided labels
    :param prob: probability to flip along each specified axis

    example 1:
    if input is a tensor of shape (batchsize, 10, 100, 200, 3)
    output = RandomFlip()(input) will randomly flip input along one of the 1st, 2nd, or 3rd axis (i.e. those with shape
    10, 100, 200).

    example 2:
    if input is a tensor of shape (batchsize, 10, 100, 200, 3)
    output = RandomFlip(axis=1)(input) will randomly flip input along the axis of shape 100, i.e. the axis
    with index 1 if we don't count the batch axis.

    example 3:
    input = tf.convert_to_tensor(np.array([[1, 0, 0, 0, 0, 0, 0],
                                           [1, 0, 0, 0, 2, 2, 0],
                                           [1, 0, 0, 0, 2, 2, 0],
                                           [1, 0, 0, 0, 2, 2, 0],
                                           [1, 0, 0, 0, 0, 0, 0]]))
    label_list = np.array([0, 1, 2])
    n_neutral_labels = 1
    output = RandomFlip(axis=1, swap_labels=True, label_list=label_list, n_neutral_labels=n_neutral_labels)(input)
    where output will either be equal to input (bear in mind that the flipping occurs with probability 0.5), or:
    output = [[0, 0, 0, 0, 0, 0, 2],
              [0, 1, 1, 0, 0, 0, 2],
              [0, 1, 1, 0, 0, 0, 2],
              [0, 1, 1, 0, 0, 0, 2],
              [0, 0, 0, 0, 0, 0, 2]]
    Note that the input must have a dtype int32 or int64 for its values to be swapped, otherwise an error will be
    raised.

    example 4:
    if labels is the same as the input of example 3, and image is a float32 image, then we can swap both the labels
    and the image consistently with:
    labels, image = RandomFlip(axis=1, swap_labels=[True, False], label_list=label_list,
                               n_neutral_labels=n_neutral_labels)([labels, image])
    Note that the labels must have a dtype int32 or int64 to be swapped, otherwise an error will be raised.
    This doesn't concern the image input, as its values are not swapped.
    """

    def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs):

        # shape attributes
        self.several_inputs = True
        self.n_dims = None
        self.list_n_channels = None

        # axis along which to flip
        self.axis = utils.reformat_to_list(axis)
        self.flip_axes = None

        # whether to swap labels, and corresponding label list
        self.swap_labels = utils.reformat_to_list(swap_labels)
        self.label_list = label_list
        self.n_neutral_labels = n_neutral_labels
        self.swap_lut = None

        self.prob = prob

        super(RandomFlip, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["axis"] = self.axis
        config["swap_labels"] = self.swap_labels
        config["label_list"] = self.label_list
        config["n_neutral_labels"] = self.n_neutral_labels
        config["prob"] = self.prob
        return config

    def build(self, input_shape):

        if not isinstance(input_shape, list):
            self.several_inputs = False
            inputshape = [input_shape]
        else:
            inputshape = input_shape
        self.n_dims = len(inputshape[0][1:-1])
        self.list_n_channels = [i[-1] for i in inputshape]
        self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape))
        self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis

        # create label list with swapped labels
        if any(self.swap_labels):
            assert (self.label_list is not None) & (self.n_neutral_labels is not None), \
                'please provide a label_list, and n_neutral_labels when swapping the values of at least one input'
            n_labels = len(self.label_list)
            if self.n_neutral_labels == n_labels:
                self.swap_labels = [False] * len(self.swap_labels)
            else:
                rl_split = np.split(self.label_list, [self.n_neutral_labels,
                                                      self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)])
                label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1]))
                swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap)
                self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32')

        self.built = True
        super(RandomFlip, self).build(input_shape)

    def call(self, inputs, **kwargs):

        # convert inputs to list, and get each input type
        inputs = [inputs] if not self.several_inputs else inputs
        types = [v.dtype for v in inputs]

        # store whether to flip along each specified dimension
        batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0]
        size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0)
        rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob)

        # swap right/left labels if we apply an odd number of flips
        odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0
        swapped_inputs = list()
        for i in range(len(inputs)):
            if self.swap_labels[i]:
                swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i]))
            else:
                swapped_inputs.append(inputs[i])

        # flip inputs and convert them back to their original type
        inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1)
        inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32)
        inputs = tf.split(inputs, self.list_n_channels, axis=-1)

        if self.several_inputs:
            return [tf.cast(v, t) for (t, v) in zip(types, inputs)]
        else:
            return tf.cast(inputs[0], types[0])

    def _single_swap(self, inputs):
        return K.switch(inputs[1], tf.gather(self.swap_lut, inputs[0]), inputs[0])

    @staticmethod
    def _single_flip(inputs):
        flip_axis = tf.where(inputs[1])
        return K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0]))


|
|
class SampleConditionalGMM(Layer):
    """This layer generates an image by sampling a Gaussian Mixture Model conditioned on a label map given as input.
    The parameters of the GMM are given as two additional inputs to the layer (means and standard deviations):
    image = SampleConditionalGMM(generation_labels)([label_map, means, stds])

    :param generation_labels: list of all possible label values contained in the input label maps.
    Must be a list or a 1D numpy array of size N, where N is the total number of possible label values.

    Layer inputs:
    label_map: input label map of shape [batchsize, shape_dim1, ..., shape_dimn, n_channel].
    All the values of label_map must be contained in generation_labels, but the input label_map doesn't necessarily
    have to contain all the values in generation_labels.
    means: tensor containing the mean values of all Gaussian distributions of the GMM.
    It must be of shape [batchsize, N, n_channel], and in the same order as generation_labels,
    i.e. the i-th value of generation_labels will be associated with the i-th value of means.
    stds: same as means but for the standard deviations of the GMM.
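
    example (an illustrative sketch with hypothetical tensors and shapes):
    if label_map has shape [batchsize, 160, 160, 160, 1], generation_labels = [0, 24, 507], and means and stds both
    have shape [batchsize, 3, 1], then
    image = SampleConditionalGMM(generation_labels)([label_map, means, stds])
    returns a tensor of shape [batchsize, 160, 160, 160, 1], where every voxel labelled 24 is sampled from a normal
    distribution of mean means[:, 1] and standard deviation stds[:, 1], and similarly for the other labels.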
|
|
    """

    def __init__(self, generation_labels, **kwargs):
        self.generation_labels = generation_labels
        self.n_labels = None
        self.n_channels = None
        self.max_label = None
        self.indices = None
        self.shape = None
        super(SampleConditionalGMM, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["generation_labels"] = self.generation_labels
        return config

    def build(self, input_shape):

        # check n_labels and n_channels
        assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).'
        self.n_channels = input_shape[1][-1]
        self.n_labels = len(self.generation_labels)
        assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels'
        assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels'

        # scatter parameters (to build mean/std lut)
        self.max_label = np.max(self.generation_labels) + 1
        indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1)
        self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32')
        self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32')

        self.built = True
        super(SampleConditionalGMM, self).build(input_shape)

    def call(self, inputs, **kwargs):

        # reformat labels and scatter indices
        batch = tf.split(tf.shape(inputs[0]), [1, -1])[0]
        tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0))
        labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1)

        # build mean map
        means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1)
        tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0)
        means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape)
        means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32)

        # same for stds
        stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1)
        stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape)
        stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32)

        return stds_map * tf.random.normal(tf.shape(labels)) + means_map

    def compute_output_shape(self, input_shape):
        return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels])


class SampleResolution(Layer):
    """Build a random resolution tensor by sampling a uniform distribution of provided range.

    You can use this layer in the following ways:
    resolution = SampleResolution(min_resolution)(), in which case resolution will be a tensor of shape (n_dims,),
    where n_dims is the length of the min_resolution parameter (provided as a list, see below).
    resolution = SampleResolution(min_resolution)(input), where input is a tensor for which the first dimension
    represents the batch_size. In this case resolution will be a tensor of shape (batchsize, n_dims,).

    :param min_resolution: list of length n_dims specifying the lower bounds of the uniform distributions to
    sample from for each value.
    :param max_res_iso: If not None, all the values of resolution will be equal to the same value, which is randomly
    sampled at each minibatch in U(min_resolution, max_res_iso).
    :param max_res_aniso: If not None, we first randomly select a direction i in the range [0, n_dims-1], and we sample
    a value in the corresponding uniform distribution U(min_resolution[i], max_res_aniso[i]).
    The other values of resolution will be set to min_resolution.
    :param prob_iso: if both max_res_iso and max_res_aniso are specified, probability of sampling an isotropic
    resolution (therefore using max_res_iso) rather than an anisotropic resolution (which would use max_res_aniso).
    :param prob_min: if not zero, probability of returning an output resolution equal to min_resolution.
    :param return_thickness: if set to True, this layer will also return a thickness value of the same shape as
    resolution, which will be sampled independently for each axis from the uniform distribution
    U(min_resolution, resolution).
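
    example (an illustrative sketch; `image` is a hypothetical tensor with a batch dimension):
    res, thickness = SampleResolution(min_resolution=[1., 1., 1.], max_res_aniso=[9., 9., 9.])(image)
    res and thickness are tensors of shape (batchsize, 3). In most cases res equals [1., 1., 1.] except along one
    randomly chosen axis, where it is sampled in U(1, 9) (with probability prob_min, res is simply [1., 1., 1.]),
    and thickness is sampled in U(min_resolution, res) for each axis.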
|
|
    """

    def __init__(self,
                 min_resolution,
                 max_res_iso=None,
                 max_res_aniso=None,
                 prob_iso=0.1,
                 prob_min=0.05,
                 return_thickness=True,
                 **kwargs):

        self.min_res = min_resolution
        self.max_res_iso_input = max_res_iso
        self.max_res_iso = None
        self.max_res_aniso_input = max_res_aniso
        self.max_res_aniso = None
        self.prob_iso = prob_iso
        self.prob_min = prob_min
        self.return_thickness = return_thickness
        self.n_dims = len(self.min_res)
        self.add_batchsize = False
        self.min_res_tens = None
        super(SampleResolution, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["min_resolution"] = self.min_res
        config["max_res_iso"] = self.max_res_iso
        config["max_res_aniso"] = self.max_res_aniso
        config["prob_iso"] = self.prob_iso
        config["prob_min"] = self.prob_min
        config["return_thickness"] = self.return_thickness
        return config

    def build(self, input_shape):

        # check maximum resolutions
        assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \
            'at least one of maximum isotropic or anisotropic resolutions must be provided, received none'

        # reformat resolutions as numpy arrays
        self.min_res = np.array(self.min_res)
        if self.max_res_iso_input is not None:
            self.max_res_iso = np.array(self.max_res_iso_input)
            assert len(self.min_res) == len(self.max_res_iso), \
                'min and isotropic max resolution must have the same length, ' \
                'had {0} and {1}'.format(self.min_res, self.max_res_iso)
            if np.array_equal(self.min_res, self.max_res_iso):
                self.max_res_iso = None
        if self.max_res_aniso_input is not None:
            self.max_res_aniso = np.array(self.max_res_aniso_input)
            assert len(self.min_res) == len(self.max_res_aniso), \
                'min and anisotropic max resolution must have the same length, ' \
                'had {} and {}'.format(self.min_res, self.max_res_aniso)
            if np.array_equal(self.min_res, self.max_res_aniso):
                self.max_res_aniso = None

        # check prob iso
        if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0):
            raise Exception('prob_iso is 0 although sampling both isotropic and anisotropic resolutions is enabled')

        if input_shape:
            self.add_batchsize = True

        self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')

        self.built = True
        super(SampleResolution, self).build(input_shape)

    def call(self, inputs, **kwargs):

        if not self.add_batchsize:
            shape = [self.n_dims]
            dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32')
            mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim,
                                               tf.convert_to_tensor([True], dtype='bool'))
        else:
            batch = tf.split(tf.shape(inputs), [1, -1])[0]
            tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0)
            self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)

            shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0)
            indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1)
            mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool'))

        # return min resolution as tensor if min=max
        if (self.max_res_iso is None) & (self.max_res_aniso is None):
            new_resolution = self.min_res_tens

        # sample isotropic resolution only
        elif (self.max_res_iso is not None) & (self.max_res_aniso is None):
            new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso)
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
                                      self.min_res_tens,
                                      new_resolution_iso)

        # sample anisotropic resolution only
        elif (self.max_res_iso is None) & (self.max_res_aniso is not None):
            new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso)
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
                                      self.min_res_tens,
                                      tf.where(mask, new_resolution_aniso, self.min_res_tens))

        # sample either anisotropic or isotropic resolution
        else:
            new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso)
            new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso)
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)),
                                      new_resolution_iso,
                                      tf.where(mask, new_resolution_aniso, self.min_res_tens))
            new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
                                      self.min_res_tens,
                                      new_resolution)

        if self.return_thickness:
            return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)]
        else:
            return new_resolution

    def compute_output_shape(self, input_shape):
        if self.return_thickness:
            return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2
        else:
            return (None, self.n_dims) if self.add_batchsize else self.n_dims


|
|
class GaussianBlur(Layer):
    """Applies gaussian blur to an input image.
    The input image is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    :param sigma: standard deviation of the blurring kernels to apply. Can be a number, a list of length n_dims, or
    a numpy array.
    :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where
    sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds
    [1/random_blur_range, random_blur_range].
    :param use_mask: (optional) whether a mask of the input will be provided as an additional layer input. This is used
    to mask the blurred image, and to correct for edge blurring effects.

    example 1:
    output = GaussianBlur(sigma=0.5)(input) will isotropically blur the input with a gaussian kernel of std 0.5.

    example 2:
    if input is a tensor of shape [batchsize, 10, 100, 200, 2]
    output = GaussianBlur(sigma=[0.5, 1, 10])(input) will blur the input with a different gaussian kernel in each
    dimension.

    example 3:
    output = GaussianBlur(sigma=0.5, random_blur_range=1.15)(input)
    will blur the input with a different gaussian kernel in each dimension, where the standard deviation of each
    kernel is uniformly sampled from [0.5/1.15; 0.5*1.15].

    example 4:
    output = GaussianBlur(sigma=0.5, use_mask=True)([input, mask])
    will 1) blur the input with a gaussian kernel of std 0.5, 2) mask the blurred image with the provided mask, and
    3) correct for edge blurring effects. If the provided mask is not of boolean type, positive values are treated
    as True.
    """

    def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
        self.sigma = utils.reformat_to_list(sigma)
        assert np.all(np.array(self.sigma) >= 0), 'sigma should be greater than or equal to 0'
        self.use_mask = use_mask

        self.n_dims = None
        self.n_channels = None
        self.blur_range = random_blur_range
        self.stride = None
        self.separable = None
        self.kernels = None
        self.convnd = None
        super(GaussianBlur, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["sigma"] = self.sigma
        config["random_blur_range"] = self.blur_range
        config["use_mask"] = self.use_mask
        return config

    def build(self, input_shape):

        # get shapes
        if self.use_mask:
            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
            self.n_dims = len(input_shape[0]) - 2
            self.n_channels = input_shape[0][-1]
        else:
            self.n_dims = len(input_shape) - 2
            self.n_channels = input_shape[-1]

        # prepare blurring kernel
        self.stride = [1] * (self.n_dims + 2)
        self.sigma = utils.reformat_to_list(self.sigma, length=self.n_dims)
        self.separable = np.linalg.norm(np.array(self.sigma)) > 5
        if self.blur_range is None:  # fixed kernels
            self.kernels = l2i_et.gaussian_kernel(self.sigma, separable=self.separable)
        else:
            self.kernels = None

        # prepare convolution
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)

        self.built = True
        super(GaussianBlur, self).build(input_shape)

    def call(self, inputs, **kwargs):

        if self.use_mask:
            image = inputs[0]
            mask = tf.cast(inputs[1], 'bool')
        else:
            image = inputs
            mask = None

        # redefine the kernels at each new step when blur_range is activated
        if self.blur_range is not None:
            self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)

        if self.separable:
            for k in self.kernels:
                if k is not None:
                    image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME')
                                       for n in range(self.n_channels)], -1)
                    if self.use_mask:
                        maskb = tf.cast(mask, 'float32')
                        maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME')
                                           for n in range(self.n_channels)], -1)
                        image = image / (maskb + K.epsilon())
                        image = tf.where(mask, image, tf.zeros_like(image))
        else:
            if any(self.sigma):
                image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME')
                                   for n in range(self.n_channels)], -1)
                if self.use_mask:
                    maskb = tf.cast(mask, 'float32')
                    maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME')
                                       for n in range(self.n_channels)], -1)
                    image = image / (maskb + K.epsilon())
                    image = tf.where(mask, image, tf.zeros_like(image))

        return image


|
|
class DynamicGaussianBlur(Layer):
    """Applies gaussian blur to an input image, where the standard deviation of the blurring kernel is provided as a
    layer input, which enables dynamic blurring (i.e. the blurring kernel can vary at each minibatch).
    :param max_sigma: maximum value of the standard deviation that will be provided as input. This is used to compute
    the size of the blurring kernels. It must be provided as a list of length n_dims.
    :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where
    sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds
    [1/random_blur_range, random_blur_range].

    example:
    blurred_image = DynamicGaussianBlur(max_sigma=[5.]*3, random_blur_range=1.15)([image, sigma])
    will return a blurred version of image, where the standard deviation of each dimension (given as a tensor, and with
    values lower than 5 for each axis) is multiplied by a random coefficient uniformly sampled from [1/1.15; 1.15].
    """

    def __init__(self, max_sigma, random_blur_range=None, **kwargs):
        self.max_sigma = max_sigma
        self.n_dims = None
        self.n_channels = None
        self.convnd = None
        self.blur_range = random_blur_range
        self.separable = None
        super(DynamicGaussianBlur, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["max_sigma"] = self.max_sigma
        config["random_blur_range"] = self.blur_range
        return config

    def build(self, input_shape):
        assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring'
        self.n_dims = len(input_shape[0]) - 2
        self.n_channels = input_shape[0][-1]
        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
        self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims)
        self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5
        self.built = True
        super(DynamicGaussianBlur, self).build(input_shape)

    def call(self, inputs, **kwargs):
        image = inputs[0]
        sigma = inputs[-1]
        kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable)
        if self.separable:
            for kernel in kernels:
                image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32)
        else:
            image = tf.map_fn(self._single_blur, [image, kernels], dtype=tf.float32)
        return image

    def _single_blur(self, inputs):
        if self.n_channels > 1:
            split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1)
            blurred_channel = list()
            for channel in split_channels:
                blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME')
                blurred_channel.append(tf.squeeze(blurred, axis=0))
            output = tf.concat(blurred_channel, -1)
        else:
            output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME')
            output = tf.squeeze(output, axis=0)
        return output


|
|
class MimicAcquisition(Layer):
    """
    Layer that takes an image as input, and simulates data that has been acquired at low resolution.
    The output is obtained by resampling the input twice:
    - first at a resolution given as an input (i.e. the "acquisition" resolution),
    - then at the output resolution (specified output shape).
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].

    :param volume_res: resolution of the provided inputs. Must be a 1-D numpy array with n_dims elements.
    :param min_subsample_res: lower bound of the acquisition resolutions to mimic (i.e. the acquisition resolution
    given as input must have values higher than min_subsample_res).
    :param resample_shape: shape of the output tensor
    :param build_dist_map: whether to return distance maps as outputs. These indicate the distance between each voxel
    and the nearest non-interpolated voxel (during the second resampling).
    :param prob_noise: probability to apply noise injection

    example 1:
    im_res = [1., 1., 1.]
    low_res = [1., 1., 3.]
    res = tf.convert_to_tensor([1., 1., 4.5])
    image is a tensor of shape (None, 256, 256, 256, 3)
    resample_shape = [256, 256, 256]
    output = MimicAcquisition(im_res, low_res, resample_shape)([image, res])
    output will be a tensor of shape (None, 256, 256, 256, 3), obtained by downsampling image to [1., 1., 4.5]
    and re-upsampling it at the initial resolution (because resample_shape is equal to the input shape). In this
    example all elements of the batch will be downsampled to the same resolution (because res has no batch dimension).
    Note that the provided res must have higher values than low_res (here used as min_subsample_res).

    example 2:
    im_res = [1., 1., 1.]
    min_low_res = [1., 1., 1.]
    res is a tensor of shape (None, 3), obtained for example by using the SampleResolution layer (see above).
    image is a tensor of shape (None, 256, 256, 256, 1)
    resample_shape = [128, 128, 128]
    output = MimicAcquisition(im_res, min_low_res, resample_shape)([image, res])
    output will be a tensor of shape (None, 128, 128, 128, 1), obtained by downsampling each element of the batch to
    the matching resolution in res, and resampling them all to half the initial resolution.
    Note that the provided res must have higher values than min_low_res.
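
    example 3 (an illustrative sketch of build_dist_map, reusing the hypothetical tensors of example 2):
    output, dist = MimicAcquisition(im_res, min_low_res, resample_shape, build_dist_map=True)([image, res])
    dist has the same shape as output and gives, for each voxel, the distance to the nearest voxel that was actually
    sampled at the acquisition resolution (rather than interpolated) during the second resampling.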
|
|
    """

    def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False,
                 noise_std=0, prob_noise=0.95, **kwargs):

        # resolutions and dimensions
        self.volume_res = volume_res
        self.min_subsample_res = min_subsample_res
        self.n_dims = len(self.volume_res)
        self.n_channels = None
        self.add_batchsize = None

        # noise
        self.noise_std = noise_std
        self.prob_noise = prob_noise

        # input and output shapes
        self.inshape = None
        self.resample_shape = resample_shape

        # meshgrids for resampling
        self.down_grid = None
        self.up_grid = None

        # whether to return a map indicating the distance from the interpolated voxels to the acquired ones
        self.build_dist_map = build_dist_map

        super(MimicAcquisition, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["volume_res"] = self.volume_res
        config["min_subsample_res"] = self.min_subsample_res
        config["resample_shape"] = self.resample_shape
        config["build_dist_map"] = self.build_dist_map
        config["noise_std"] = self.noise_std
        config["prob_noise"] = self.prob_noise
        return config

    def build(self, input_shape):

        # set up input shape and acquisition shape
        self.inshape = input_shape[0][1:]
        self.n_channels = input_shape[0][-1]
        self.add_batchsize = False if (input_shape[1][0] is None) else True
        down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res)

        # build interpolation meshgrids
        self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0)
        self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0)

        self.built = True
        super(MimicAcquisition, self).build(input_shape)

    def call(self, inputs, **kwargs):

        # sort inputs
        assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution'
        vol = inputs[0]
        subsample_res = tf.cast(inputs[1], dtype='float32')
        vol = K.reshape(vol, [-1, *self.inshape])  # necessary for multi_gpu models
        batchsize = tf.split(tf.shape(vol), [1, -1])[0]
        tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0)

        # get downsampling and upsampling factors
        if self.add_batchsize:
            subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape)
        down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') /
                             subsample_res, dtype='int32')
        down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32')
        up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32')

        # downsample
        down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0))
        down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims)
        inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape)
        inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims)
        down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32'))
        vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32)

        # add noise with predefined probability
        if self.noise_std > 0:
            sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'),
                                      self.n_channels * tf.ones([1], dtype='int32')], 0)
            noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std))
            if self.prob_noise == 1:
                vol += noise
            else:
                vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol)

        # upsample
        up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0))
        up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims)
        vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32)

        # return upsampled volume
        if not self.build_dist_map:
            return vol

        # return upsampled volumes with distance maps
        else:

            # get grid points
            floor = tf.math.floor(up_loc)
            ceil = tf.math.ceil(up_loc)

            # get distances of every voxel to higher and lower grid points for every dimension
            f_dist = up_loc - floor
            c_dist = ceil - up_loc

            # keep minimum 1d distances, and compute 3d distance to nearest grid point
            dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims)
            dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True))

            return [vol, dist]

    @staticmethod
    def _single_down_interpn(inputs):
        return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest')

    @staticmethod
    def _single_up_interpn(inputs):
        return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear')

    def compute_output_shape(self, input_shape):
        output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]])
        return [output_shape] * 2 if self.build_dist_map else output_shape


|
|
class BiasFieldCorruption(Layer):
    """This layer applies a smooth random bias field to the input by applying the following steps:
    1) we first sample a value for the standard deviation of a centred normal distribution
    2) a small-size field is sampled from this normal distribution
    3) the small field is then resized with trilinear interpolation to image size
    4) it is rescaled to positive values by taking the voxel-wise exponential
    5) it is multiplied with the input tensor.
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].

    :param bias_field_std: maximum value of the standard deviation sampled in step 1 (it will be sampled from the
    range [0, bias_field_std])
    :param bias_scale: ratio between the shape of the input tensor and the shape of the sampled bias field.
    :param same_bias_for_all_channels: whether to apply the same bias field to all the channels of the input tensor.
    :param prob: probability to apply this bias field corruption.
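
    example (an illustrative sketch; `image` is a hypothetical tensor of shape [batchsize, 160, 160, 160, 1]):
    corrupted = BiasFieldCorruption(bias_field_std=.5, bias_scale=.025)(image)
    multiplies image voxel-wise by a smooth positive field: a small field of shape [4, 4, 4] (i.e. 160 * 0.025) is
    sampled, resized to full image size with trilinear interpolation, and exponentiated before the multiplication.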
|
|
    """

    def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs):

        # input shape
        self.several_inputs = False
        self.inshape = None
        self.n_dims = None
        self.n_channels = None

        # sampling shape
        self.std_shape = None
        self.small_bias_shape = None

        # bias field parameters
        self.bias_field_std = bias_field_std
        self.bias_scale = bias_scale
        self.same_bias_for_all_channels = same_bias_for_all_channels
        self.prob = prob

        super(BiasFieldCorruption, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["bias_field_std"] = self.bias_field_std
        config["bias_scale"] = self.bias_scale
        config["same_bias_for_all_channels"] = self.same_bias_for_all_channels
        config["prob"] = self.prob
        return config

    def build(self, input_shape):

        # input shape
        if isinstance(input_shape, list):
            self.several_inputs = True
            self.inshape = input_shape
        else:
            self.inshape = [input_shape]
        self.n_dims = len(self.inshape[0]) - 2
        self.n_channels = self.inshape[0][-1]

        # sampling shapes
        self.std_shape = [1] * (self.n_dims + 1)
        self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1)
        if not self.same_bias_for_all_channels:
            self.std_shape[-1] = self.n_channels
            self.small_bias_shape[-1] = self.n_channels

        self.built = True
        super(BiasFieldCorruption, self).build(input_shape)

    def call(self, inputs, **kwargs):

        if not self.several_inputs:
            inputs = [inputs]

        if self.bias_field_std > 0:

            # sampling shapes
            batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0]
            std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0)
            bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0)

            # sample small bias field
            bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std))

            # resize bias field and take exponential
            bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field)
            bias_field = tf.math.exp(bias_field)

            # apply bias field with predefined probability
            if self.prob == 1:
                return [tf.math.multiply(bias_field, v) for v in inputs]
            else:
                rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob))
                if self.several_inputs:
                    return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs]
                else:
                    return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0])

        else:
            return inputs


class IntensityAugmentation(Layer):
    """This layer augments the intensities of the input tensor, and can also apply min-max normalisation.
    The following steps are applied (all are optional):
    1) white noise corruption, with a randomly sampled standard deviation.
    2) clip the input between two values
    3) min-max normalisation
    4) gamma augmentation (i.e. voxel-wise exponentiation by a randomly sampled power)
    The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].

    :param noise_std: maximum value of the standard deviation of the Gaussian white noise used in step 1 (it will be
    sampled from the range [0, noise_std]). Set to 0 to skip this step.
    :param clip: clip the input tensor between the given values. Can either be: a number (in which case we clip between
    0 and the given value), or a list or a numpy array with two elements. Default is 0, where no clipping occurs.
    :param normalise: whether to apply min-max normalisation, to normalise between 0 and 1. Default is True.
    :param norm_perc: percentiles (between 0 and 1) of the sorted intensity values for robust normalisation. Can be:
    a number (in which case the robust minimum is the provided percentile of sorted values, and the maximum is the
    1 - norm_perc percentile), or a list/numpy array of 2 elements (percentiles for the minimum and maximum values).
    The minimum and maximum values are computed separately for each channel if separate_channels is True.
    Default is 0, where we simply take the minimum and maximum values.
    :param gamma_std: standard deviation of the normal distribution from which we sample gamma (in log domain).
    Default is 0, where no gamma augmentation occurs.
    :param contrast_inversion: whether to perform contrast inversion (i.e. 1 - x). If True, this is performed randomly
    for each element of the batch, as well as for each channel.
    :param separate_channels: whether to augment all channels separately. Default is True.
|
|
1124 |
:param prob_noise: probability to apply noise injection |
|
|
1125 |
:param prob_gamma: probability to apply gamma augmentation |
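example (illustrative sketch only; the shape and parameter values below are assumed):
image = keras.layers.Input(shape=(160, 160, 160, 1))
augmented = IntensityAugmentation(noise_std=.05, clip=300, normalise=True, norm_perc=0.02, gamma_std=.2)(image)
here the input is corrupted with Gaussian noise, clipped to [0, 300], robustly rescaled to [0, 1]
(using the 2nd and 98th percentiles of the sorted intensities), and finally raised to a random
power exp(gamma), with gamma sampled from N(0, 0.2).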
|
|
1126 |
""" |
|
|
1127 |
|
|
|
1128 |
def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False, |
|
|
1129 |
separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs): |
|
|
1130 |
|
|
|
1131 |
# shape attributes |
|
|
1132 |
self.n_dims = None |
|
|
1133 |
self.n_channels = None |
|
|
1134 |
self.flatten_shape = None |
|
|
1135 |
self.expand_minmax_dim = None |
|
|
1136 |
self.one = None |
|
|
1137 |
|
|
|
1138 |
# inputs |
|
|
1139 |
self.noise_std = noise_std |
|
|
1140 |
self.clip = clip |
|
|
1141 |
self.clip_values = None |
|
|
1142 |
self.normalise = normalise |
|
|
1143 |
self.norm_perc = norm_perc |
|
|
1144 |
self.perc = None |
|
|
1145 |
self.gamma_std = gamma_std |
|
|
1146 |
self.separate_channels = separate_channels |
|
|
1147 |
self.contrast_inversion = contrast_inversion |
|
|
1148 |
self.prob_noise = prob_noise |
|
|
1149 |
self.prob_gamma = prob_gamma |
|
|
1150 |
|
|
|
1151 |
super(IntensityAugmentation, self).__init__(**kwargs) |
|
|
1152 |
|
|
|
1153 |
def get_config(self): |
|
|
1154 |
config = super().get_config() |
|
|
1155 |
config["noise_std"] = self.noise_std |
|
|
1156 |
config["clip"] = self.clip |
|
|
1157 |
config["normalise"] = self.normalise |
|
|
1158 |
config["norm_perc"] = self.norm_perc |
|
|
1159 |
config["gamma_std"] = self.gamma_std |
|
|
1160 |
config["separate_channels"] = self.separate_channels |
|
|
1161 |
config["prob_noise"] = self.prob_noise |
|
|
1162 |
config["prob_gamma"] = self.prob_gamma |
|
|
1163 |
return config |
|
|
1164 |
|
|
|
1165 |
def build(self, input_shape): |
|
|
1166 |
self.n_dims = len(input_shape) - 2 |
|
|
1167 |
self.n_channels = input_shape[-1] |
|
|
1168 |
self.flatten_shape = np.prod(np.array(input_shape[1:-1])) |
|
|
1169 |
self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape |
|
|
1170 |
self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1 |
|
|
1171 |
self.one = tf.ones([1], dtype='int32') |
|
|
1172 |
if self.clip: |
|
|
1173 |
self.clip_values = utils.reformat_to_list(self.clip) |
|
|
1174 |
self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]] |
|
|
1175 |
else: |
|
|
1176 |
self.clip_values = None |
|
|
1177 |
if self.norm_perc: |
|
|
1178 |
self.perc = utils.reformat_to_list(self.norm_perc) |
|
|
1179 |
self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] |
|
|
1180 |
else: |
|
|
1181 |
self.perc = None |
|
|
1182 |
|
|
|
1183 |
self.built = True |
|
|
1184 |
super(IntensityAugmentation, self).build(input_shape) |
|
|
1185 |
|
|
|
1186 |
def call(self, inputs, **kwargs): |
|
|
1187 |
|
|
|
1188 |
# prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) |
|
|
1189 |
batchsize = tf.split(tf.shape(inputs), [1, -1])[0] |
|
|
1190 |
if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: |
|
|
1191 |
sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0) |
|
|
1192 |
if self.separate_channels: |
|
|
1193 |
sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) |
|
|
1194 |
else: |
|
|
1195 |
sample_shape = tf.concat([sample_shape, self.one], 0) |
|
|
1196 |
else: |
|
|
1197 |
sample_shape = None |
|
|
1198 |
|
|
|
1199 |
# add noise with predefined probability |
|
|
1200 |
if self.noise_std > 0: |
|
|
1201 |
noise_stddev = tf.random.uniform(sample_shape, maxval=self.noise_std) |
|
|
1202 |
if self.separate_channels: |
|
|
1203 |
noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) |
|
|
1204 |
else: |
|
|
1205 |
noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev) |
|
|
1206 |
noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels])) |
|
|
1207 |
if self.prob_noise == 1: |
|
|
1208 |
inputs = inputs + noise |
|
|
1209 |
else: |
|
|
1210 |
inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), |
|
|
1211 |
inputs + noise, inputs) |
|
|
1212 |
|
|
|
1213 |
# clip images to given values |
|
|
1214 |
if self.clip_values is not None: |
|
|
1215 |
inputs = K.clip(inputs, self.clip_values[0], self.clip_values[1]) |
|
|
1216 |
|
|
|
1217 |
# normalise |
|
|
1218 |
if self.normalise: |
|
|
1219 |
# define robust min and max by sorting values and taking percentile |
|
|
1220 |
if self.perc is not None: |
|
|
1221 |
if self.separate_channels: |
|
|
1222 |
shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0) |
|
|
1223 |
else: |
|
|
1224 |
shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) |
|
|
1225 |
intensities = tf.sort(tf.reshape(inputs, shape), axis=1) |
|
|
1226 |
m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] |
|
|
1227 |
M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...] |
|
|
1228 |
# simple min and max |
|
|
1229 |
else: |
|
|
1230 |
m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) |
|
|
1231 |
M = K.max(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) |
|
|
1232 |
# normalise |
|
|
1233 |
m = l2i_et.expand_dims(m, axis=[1] * self.expand_minmax_dim) |
|
|
1234 |
M = l2i_et.expand_dims(M, axis=[1] * self.expand_minmax_dim) |
|
|
1235 |
inputs = tf.clip_by_value(inputs, m, M) |
|
|
1236 |
inputs = (inputs - m) / (M - m + K.epsilon()) |
|
|
1237 |
|
|
|
1238 |
# apply voxel-wise exponentiation with predefined probability |
|
|
1239 |
if self.gamma_std > 0: |
|
|
1240 |
gamma = tf.random.normal(sample_shape, stddev=self.gamma_std) |
|
|
1241 |
if self.prob_gamma == 1: |
|
|
1242 |
inputs = tf.math.pow(inputs, tf.math.exp(gamma)) |
|
|
1243 |
else: |
|
|
1244 |
inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), |
|
|
1245 |
tf.math.pow(inputs, tf.math.exp(gamma)), inputs) |
|
|
1246 |
|
|
|
1247 |
# apply random contrast inversion |
|
|
1248 |
if self.contrast_inversion: |
|
|
1249 |
rand_invert = tf.less(tf.random.uniform(sample_shape, maxval=1), 0.5) |
|
|
1250 |
split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) |
|
|
1251 |
split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) |
|
|
1252 |
inverted_channel = list() |
|
|
1253 |
for (channel, invert) in zip(split_channels, split_rand_invert): |
|
|
1254 |
inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype)) |
|
|
1255 |
inputs = tf.concat(inverted_channel, -1) |
|
|
1256 |
|
|
|
1257 |
return inputs |
|
|
1258 |
|
|
|
1259 |
@staticmethod |
|
|
1260 |
def _single_invert(inputs): |
|
|
1261 |
return K.switch(tf.squeeze(inputs[1]), 1 - inputs[0], inputs[0]) |
|
|
1262 |
|
|
|
1263 |
|
|
|
1264 |
class DiceLoss(Layer): |
|
|
1265 |
"""This layer computes the soft Dice loss between two tensors. |
|
|
1266 |
These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. |
|
|
1267 |
The first input tensor is the GT and the second is the prediction: dice_loss = DiceLoss()([gt, pred]) |
|
|
1268 |
|
|
|
1269 |
:param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. |
|
|
1270 |
Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to |
|
|
1271 |
the inverse of the volume of each label in the ground truth. |
|
|
1272 |
:param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures |
|
|
1273 |
when computing the loss. Default is 0 where no boundary weighting is applied. |
|
|
1274 |
:param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels |
|
|
1275 |
within this distance to a region boundary. Default is 3. |
|
|
1276 |
:param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be |
|
|
1277 |
redundant when we have several labels. This is only used if boundary_weight is not 0. |
|
|
1278 |
:param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label |
|
|
1279 |
probabilities sum to 1 at each voxel location). Default is True. |
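example (formula computed in call(), written out here for clarity):
for each batch element and each label l, with sums taken over all voxels,
dice_l = (sum(2 * gt_l * pred_l) + eps) / (sum(gt_l ** 2 + pred_l ** 2) + eps)
and the returned scalar is the mean of (1 - dice_l), possibly after weighting across labels.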
|
|
1280 |
""" |
|
|
1281 |
|
|
|
1282 |
def __init__(self, |
|
|
1283 |
class_weights=None, |
|
|
1284 |
boundary_weights=0, |
|
|
1285 |
boundary_dist=3, |
|
|
1286 |
skip_background=True, |
|
|
1287 |
enable_checks=True, |
|
|
1288 |
**kwargs): |
|
|
1289 |
|
|
|
1290 |
self.class_weights = class_weights |
|
|
1291 |
self.dynamic_weighting = False |
|
|
1292 |
self.class_weights_tens = None |
|
|
1293 |
self.boundary_weights = boundary_weights |
|
|
1294 |
self.boundary_dist = boundary_dist |
|
|
1295 |
self.skip_background = skip_background |
|
|
1296 |
self.enable_checks = enable_checks |
|
|
1297 |
self.spatial_axes = None |
|
|
1298 |
self.avg_pooling_layer = None |
|
|
1299 |
super(DiceLoss, self).__init__(**kwargs) |
|
|
1300 |
|
|
|
1301 |
def get_config(self): |
|
|
1302 |
config = super().get_config() |
|
|
1303 |
config["class_weights"] = self.class_weights |
|
|
1304 |
config["boundary_weights"] = self.boundary_weights |
|
|
1305 |
config["boundary_dist"] = self.boundary_dist |
|
|
1306 |
config["skip_background"] = self.skip_background |
|
|
1307 |
config["enable_checks"] = self.enable_checks |
|
|
1308 |
return config |
|
|
1309 |
|
|
|
1310 |
def build(self, input_shape): |
|
|
1311 |
|
|
|
1312 |
# get shape |
|
|
1313 |
assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' |
|
|
1314 |
assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' |
|
|
1315 |
inshape = input_shape[0][1:] |
|
|
1316 |
n_dims = len(inshape[:-1]) |
|
|
1317 |
n_labels = inshape[-1] |
|
|
1318 |
self.spatial_axes = list(range(1, n_dims + 1)) |
|
|
1319 |
self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) |
|
|
1320 |
self.skip_background = False if n_labels == 1 else self.skip_background |
|
|
1321 |
|
|
|
1322 |
# build tensor with class weights |
|
|
1323 |
if self.class_weights is not None: |
|
|
1324 |
if self.class_weights == -1: |
|
|
1325 |
self.dynamic_weighting = True |
|
|
1326 |
else: |
|
|
1327 |
class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) |
|
|
1328 |
class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') |
|
|
1329 |
self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) |
|
|
1330 |
|
|
|
1331 |
self.built = True |
|
|
1332 |
super(DiceLoss, self).build(input_shape) |
|
|
1333 |
|
|
|
1334 |
def call(self, inputs, **kwargs): |
|
|
1335 |
|
|
|
1336 |
# make sure tensors are probabilistic |
|
|
1337 |
gt = inputs[0] |
|
|
1338 |
pred = inputs[1] |
|
|
1339 |
if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps |
|
|
1340 |
gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) |
|
|
1341 |
pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) |
|
|
1342 |
|
|
|
1343 |
# compute dice loss for each label |
|
|
1344 |
top = 2 * gt * pred |
|
|
1345 |
bottom = tf.math.square(gt) + tf.math.square(pred) |
|
|
1346 |
|
|
|
1347 |
# apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) |
|
|
1348 |
if self.boundary_weights: |
|
|
1349 |
avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) |
|
|
1350 |
boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') |
|
|
1351 |
if self.skip_background: |
|
|
1352 |
boundaries_channels = tf.unstack(boundaries, axis=-1) |
|
|
1353 |
boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) |
|
|
1354 |
boundary_weights_tensor = 1 + self.boundary_weights * boundaries |
|
|
1355 |
top *= boundary_weights_tensor |
|
|
1356 |
bottom *= boundary_weights_tensor |
|
|
1357 |
else: |
|
|
1358 |
boundary_weights_tensor = None |
|
|
1359 |
|
|
|
1360 |
# compute loss |
|
|
1361 |
top = tf.math.reduce_sum(top, self.spatial_axes) |
|
|
1362 |
bottom = tf.math.reduce_sum(bottom, self.spatial_axes) |
|
|
1363 |
dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon()) |
|
|
1364 |
loss = 1 - dice |
|
|
1365 |
|
|
|
1366 |
# apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). |
|
|
1367 |
if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt |
|
|
1368 |
if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume |
|
|
1369 |
self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes) |
|
|
1370 |
else: |
|
|
1371 |
self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) |
|
|
1372 |
if self.class_weights_tens is not None: |
|
|
1373 |
self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) |
|
|
1374 |
loss = tf.reduce_sum(loss * self.class_weights_tens, -1) |
|
|
1375 |
|
|
|
1376 |
return tf.math.reduce_mean(loss) |
|
|
1377 |
|
|
|
1378 |
def compute_output_shape(self, input_shape): |
|
|
1379 |
return [[]] |
|
|
1380 |
|
|
|
1381 |
|
|
|
1382 |
class WeightedL2Loss(Layer): |
|
|
1383 |
"""This layer computes a L2 loss weighted by a specified factor (target_value) between two tensors. |
|
|
1384 |
This is designed to be used on the layer before the softmax. |
|
|
1385 |
The tensors are expected to have the same shape [batchsize, size_dim1, ..., size_dimN, n_labels]. |
|
|
1386 |
The first input tensor is the GT and the second is the prediction: wl2_loss = WeightedL2Loss()([gt, pred]) |
|
|
1387 |
|
|
|
1388 |
:param target_value: target value for the layer before softmax: target_value when gt = 1, -target_value when gt = 0. |
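example (loss computed in call(), written out here for clarity):
weights = 1 - gt[..., 0]   (plus a small epsilon, broadcast along the channel axis)
loss = sum(weights * (pred - target_value * (2 * gt - 1)) ** 2) / (sum(weights) * n_labels)
so background voxels (first channel of the GT) contribute almost nothing to the loss.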
|
|
1389 |
""" |
|
|
1390 |
|
|
|
1391 |
def __init__(self, target_value=5, **kwargs): |
|
|
1392 |
self.target_value = target_value |
|
|
1393 |
self.n_labels = None |
|
|
1394 |
super(WeightedL2Loss, self).__init__(**kwargs) |
|
|
1395 |
|
|
|
1396 |
def get_config(self): |
|
|
1397 |
config = super().get_config() |
|
|
1398 |
config["target_value"] = self.target_value |
|
|
1399 |
return config |
|
|
1400 |
|
|
|
1401 |
def build(self, input_shape): |
|
|
1402 |
assert len(input_shape) == 2, 'WeightedL2Loss expects 2 inputs to compute the loss.' |
|
|
1403 |
assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' |
|
|
1404 |
self.n_labels = input_shape[0][-1] |
|
|
1405 |
self.built = True |
|
|
1406 |
super(WeightedL2Loss, self).build(input_shape) |
|
|
1407 |
|
|
|
1408 |
def call(self, inputs, **kwargs): |
|
|
1409 |
gt = inputs[0] |
|
|
1410 |
pred = inputs[1] |
|
|
1411 |
weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1) |
|
|
1412 |
return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels) |
|
|
1413 |
|
|
|
1414 |
def compute_output_shape(self, input_shape): |
|
|
1415 |
return [[]] |
|
|
1416 |
|
|
|
1417 |
|
|
|
1418 |
class CrossEntropyLoss(Layer): |
|
|
1419 |
"""This layer computes the cross-entropy loss between two tensors. |
|
|
1420 |
These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. |
|
|
1421 |
The first input tensor is the GT and the second is the prediction: ce_loss = CrossEntropyLoss()([gt, pred]) |
|
|
1422 |
|
|
|
1423 |
:param class_weights: (optional) if given, the loss is obtained by a weighted average of the cross-entropy across labels. |
|
|
1424 |
Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to |
|
|
1425 |
the inverse of the volume of each label in the ground truth. |
|
|
1426 |
:param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures |
|
|
1427 |
when computing the loss. Default is 0 where no boundary weighting is applied. |
|
|
1428 |
:param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels |
|
|
1429 |
within this distance to a region boundary. Default is 3. |
|
|
1430 |
:param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be |
|
|
1431 |
redundant when we have several labels. This is only used if boundary_weight is not 0. |
|
|
1432 |
:param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label |
|
|
1433 |
probabilities sum to 1 at each voxel location). Default is True. |
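example (formula computed in call(), written out here for clarity):
at every voxel, ce = -sum_over_labels(gt_l * log(pred_l)); this map is optionally re-weighted near
region boundaries and/or across labels, and the returned scalar is the mean of ce over all voxels
and batch elements.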
|
|
1434 |
""" |
|
|
1435 |
|
|
|
1436 |
def __init__(self, |
|
|
1437 |
class_weights=None, |
|
|
1438 |
boundary_weights=0, |
|
|
1439 |
boundary_dist=3, |
|
|
1440 |
skip_background=True, |
|
|
1441 |
enable_checks=True, |
|
|
1442 |
**kwargs): |
|
|
1443 |
|
|
|
1444 |
self.class_weights = class_weights |
|
|
1445 |
self.dynamic_weighting = False |
|
|
1446 |
self.class_weights_tens = None |
|
|
1447 |
self.boundary_weights = boundary_weights |
|
|
1448 |
self.boundary_dist = boundary_dist |
|
|
1449 |
self.skip_background = skip_background |
|
|
1450 |
self.enable_checks = enable_checks |
|
|
1451 |
self.spatial_axes = None |
|
|
1452 |
self.avg_pooling_layer = None |
|
|
1453 |
super(CrossEntropyLoss, self).__init__(**kwargs) |
|
|
1454 |
|
|
|
1455 |
def get_config(self): |
|
|
1456 |
config = super().get_config() |
|
|
1457 |
config["class_weights"] = self.class_weights |
|
|
1458 |
config["boundary_weights"] = self.boundary_weights |
|
|
1459 |
config["boundary_dist"] = self.boundary_dist |
|
|
1460 |
config["skip_background"] = self.skip_background |
|
|
1461 |
config["enable_checks"] = self.enable_checks |
|
|
1462 |
return config |
|
|
1463 |
|
|
|
1464 |
def build(self, input_shape): |
|
|
1465 |
|
|
|
1466 |
# get shape |
|
|
1467 |
assert len(input_shape) == 2, 'CrossEntropyLoss expects 2 inputs to compute the cross-entropy loss.' |
|
|
1468 |
assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' |
|
|
1469 |
inshape = input_shape[0][1:] |
|
|
1470 |
n_dims = len(inshape[:-1]) |
|
|
1471 |
n_labels = inshape[-1] |
|
|
1472 |
self.spatial_axes = list(range(1, n_dims + 1)) |
|
|
1473 |
self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) |
|
|
1474 |
self.skip_background = False if n_labels == 1 else self.skip_background |
|
|
1475 |
|
|
|
1476 |
# build tensor with class weights |
|
|
1477 |
if self.class_weights is not None: |
|
|
1478 |
if self.class_weights == -1: |
|
|
1479 |
self.dynamic_weighting = True |
|
|
1480 |
else: |
|
|
1481 |
class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) |
|
|
1482 |
class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') |
|
|
1483 |
self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims)) |
|
|
1484 |
|
|
|
1485 |
self.built = True |
|
|
1486 |
super(CrossEntropyLoss, self).build(input_shape) |
|
|
1487 |
|
|
|
1488 |
def call(self, inputs, **kwargs): |
|
|
1489 |
|
|
|
1490 |
# make sure tensors are probabilistic |
|
|
1491 |
gt = inputs[0] |
|
|
1492 |
pred = inputs[1] |
|
|
1493 |
if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps |
|
|
1494 |
gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) |
|
|
1495 |
pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) |
|
|
1496 |
pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()) # to avoid log(0) |
|
|
1497 |
|
|
|
1498 |
# compare prediction/target; ce has the same shape as the input tensors |
|
|
1499 |
ce = -gt * tf.math.log(pred) |
|
|
1500 |
|
|
|
1501 |
# apply boundary weighting (i.e. voxels close to region boundaries will be counted several times in the loss) |
|
|
1502 |
if self.boundary_weights: |
|
|
1503 |
avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) |
|
|
1504 |
boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') |
|
|
1505 |
if self.skip_background: |
|
|
1506 |
boundaries_channels = tf.unstack(boundaries, axis=-1) |
|
|
1507 |
boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) |
|
|
1508 |
boundary_weights_tensor = 1 + self.boundary_weights * boundaries |
|
|
1509 |
ce *= boundary_weights_tensor |
|
|
1510 |
else: |
|
|
1511 |
boundary_weights_tensor = None |
|
|
1512 |
|
|
|
1513 |
# apply class weighting across labels. By the end of this, ce still has the same shape as the input tensors. |
|
|
1514 |
if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt |
|
|
1515 |
if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume |
|
|
1516 |
self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True) |
|
|
1517 |
else: |
|
|
1518 |
self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) |
|
|
1519 |
if self.class_weights_tens is not None: |
|
|
1520 |
self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) |
|
|
1521 |
ce = tf.reduce_sum(ce * self.class_weights_tens, -1) |
|
|
1522 |
|
|
|
1523 |
# sum along the label axis, and take the mean over the remaining dimensions |
|
|
1524 |
ce = tf.math.reduce_mean(tf.math.reduce_sum(ce, axis=-1)) |
|
|
1525 |
|
|
|
1526 |
return ce |
|
|
1527 |
|
|
|
1528 |
def compute_output_shape(self, input_shape): |
|
|
1529 |
return [[]] |
|
|
1530 |
|
|
|
1531 |
|
|
|
1532 |
class MomentLoss(Layer): |
|
|
1533 |
"""This layer computes a moment loss between two tensors. Specifically, it computes the distance between the centres |
|
|
1534 |
of gravity for all the channels of the two tensors, and then returns a value averaged across all channels. |
|
|
1535 |
These tensors are expected to have the same shape [batch, size_dim1, ..., size_dimN, n_channels]. |
|
|
1536 |
The first input tensor is the GT and the second is the prediction: moment_loss = MomentLoss()([gt, pred]) |
|
|
1537 |
|
|
|
1538 |
:param class_weights: (optional) if given, the loss is obtained by a weighted average of the moment distances across labels. |
|
|
1539 |
Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to |
|
|
1540 |
the inverse of the volume of each label in the ground truth. |
|
|
1541 |
:param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label |
|
|
1542 |
probabilities sum to 1 at each voxel location). Default is True. |
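example (quantity computed in call(), written out here for clarity):
for each channel, the centre of gravity is the intensity-weighted mean voxel coordinate
c = sum(tensor * coordinates) / sum(tensor),
and the loss is the Euclidean distance ||c_pred - c_gt||, averaged over channels and batch elements.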
|
|
1543 |
""" |
|
|
1544 |
|
|
|
1545 |
def __init__(self, class_weights=None, enable_checks=False, **kwargs): |
|
|
1546 |
self.class_weights = class_weights |
|
|
1547 |
self.dynamic_weighting = False |
|
|
1548 |
self.class_weights_tens = None |
|
|
1549 |
self.enable_checks = enable_checks |
|
|
1550 |
self.spatial_axes = None |
|
|
1551 |
self.coordinates = None |
|
|
1552 |
super(MomentLoss, self).__init__(**kwargs) |
|
|
1553 |
|
|
|
1554 |
def get_config(self): |
|
|
1555 |
config = super().get_config() |
|
|
1556 |
config["class_weights"] = self.class_weights |
|
|
1557 |
config["enable_checks"] = self.enable_checks |
|
|
1558 |
return config |
|
|
1559 |
|
|
|
1560 |
def build(self, input_shape): |
|
|
1561 |
|
|
|
1562 |
# get shape |
|
|
1563 |
assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the moment loss.' |
|
|
1564 |
assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' |
|
|
1565 |
inshape = input_shape[0][1:] |
|
|
1566 |
n_dims = len(inshape[:-1]) |
|
|
1567 |
n_labels = inshape[-1] |
|
|
1568 |
self.spatial_axes = list(range(1, n_dims + 1)) |
|
|
1569 |
|
|
|
1570 |
# build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan) |
|
|
1571 |
self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1) |
|
|
1572 |
self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32') |
|
|
1573 |
|
|
|
1574 |
# build tensor with class weights |
|
|
1575 |
if self.class_weights is not None: |
|
|
1576 |
if self.class_weights == -1: |
|
|
1577 |
self.dynamic_weighting = True |
|
|
1578 |
else: |
|
|
1579 |
class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) |
|
|
1580 |
class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') |
|
|
1581 |
self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) |
|
|
1582 |
|
|
|
1583 |
self.built = True |
|
|
1584 |
super(MomentLoss, self).build(input_shape) |
|
|
1585 |
|
|
|
1586 |
def call(self, inputs, **kwargs): |
|
|
1587 |
|
|
|
1588 |
# make sure tensors are probabilistic |
|
|
1589 |
gt = inputs[0] # (B, dim1, dim2, ..., dimN, nchan) |
|
|
1590 |
pred = inputs[1] |
|
|
1591 |
if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps |
|
|
1592 |
gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) |
|
|
1593 |
pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) |
|
|
1594 |
|
|
|
1595 |
# compute loss |
|
|
1596 |
gt_mean_coordinates = self._mean_coordinates(gt) # (B, ndim, nchan) |
|
|
1597 |
pred_mean_coordinates = self._mean_coordinates(pred) |
|
|
1598 |
loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1)) # (B, nchan) |
|
|
1599 |
|
|
|
1600 |
# apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). |
|
|
1601 |
if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt |
|
|
1602 |
self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) |
|
|
1603 |
if self.class_weights_tens is not None: |
|
|
1604 |
self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) |
|
|
1605 |
loss = tf.reduce_sum(loss * self.class_weights_tens, -1) |
|
|
1606 |
|
|
|
1607 |
return tf.math.reduce_mean(loss) |
|
|
1608 |
|
|
|
1609 |
def _mean_coordinates(self, tensor): |
|
|
1610 |
tensor = l2i_et.expand_dims(tensor, axis=-2) # (B, dim1, dim2, ..., dimN, 1, nchan) |
|
|
1611 |
numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes) # (B, ndim, nchan) |
|
|
1612 |
denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() |
|
|
1613 |
return numerator / denominator |
|
|
1614 |
|
|
|
1615 |
def compute_output_shape(self, input_shape): |
|
|
1616 |
return [[]] |
|
|
1617 |
|
|
|
1618 |
|
|
|
1619 |
class ResetValuesToZero(Layer): |
|
|
1620 |
"""This layer enables to reset given values to 0 within the input tensors. |
|
|
1621 |
|
|
|
1622 |
:param values: list of values to be reset to 0. |
|
|
1623 |
|
|
|
1624 |
example: |
|
|
1625 |
input = tf.convert_to_tensor(np.array([[1, 0, 2, 2, 2, 2, 0], |
|
|
1626 |
[1, 3, 3, 3, 3, 3, 3], |
|
|
1627 |
[1, 0, 0, 0, 4, 4, 4]])) |
|
|
1628 |
values = [1, 3] |
|
|
1629 |
ResetValuesToZero(values)(input) |
|
|
1630 |
>> [[0, 0, 2, 2, 2, 2, 0], |
|
|
1631 |
[0, 0, 0, 0, 0, 0, 0], |
|
|
1632 |
[0, 0, 0, 0, 4, 4, 4]] |
|
|
1633 |
""" |
|
|
1634 |
|
|
|
1635 |
def __init__(self, values, **kwargs): |
|
|
1636 |
assert values is not None, 'please provide correct list of values, received None' |
|
|
1637 |
self.values = utils.reformat_to_list(values) |
|
|
1638 |
self.values_tens = None |
|
|
1639 |
self.n_values = len(values) |
|
|
1640 |
super(ResetValuesToZero, self).__init__(**kwargs) |
|
|
1641 |
|
|
|
1642 |
def get_config(self): |
|
|
1643 |
config = super().get_config() |
|
|
1644 |
config["values"] = self.values |
|
|
1645 |
return config |
|
|
1646 |
|
|
|
1647 |
def build(self, input_shape): |
|
|
1648 |
self.values_tens = tf.convert_to_tensor(self.values) |
|
|
1649 |
self.built = True |
|
|
1650 |
super(ResetValuesToZero, self).build(input_shape) |
|
|
1651 |
|
|
|
1652 |
def call(self, inputs, **kwargs): |
|
|
1653 |
values = tf.cast(self.values_tens, dtype=inputs.dtype) |
|
|
1654 |
for i in range(self.n_values): |
|
|
1655 |
inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs) |
|
|
1656 |
return inputs |
|
|
1657 |
|
|
|
1658 |
|
|
|
1659 |
class ConvertLabels(Layer): |
|
|
1660 |
"""Convert all labels in a tensor by the corresponding given set of values. |
|
|
1661 |
labels_converted = ConvertLabels(source_values, dest_values)(labels). |
|
|
1662 |
labels must be an int32 tensor, and labels_converted will also be int32. |
|
|
1663 |
|
|
|
1664 |
:param source_values: list of all the possible values in labels. Must be a list or a 1D numpy array. |
|
|
1665 |
:param dest_values: list of all the target label values. Must be ordered the same as source values: |
|
|
1666 |
labels[labels == source_values[i]] = dest_values[i]. |
|
|
1667 |
If None (default), dest_values is equal to [0, ..., N-1], where N is the total number of values in source_values, |
|
|
1668 |
which makes it possible to remap label maps to [0, ..., N-1]. |
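example (illustrative, with made-up label values):
labels = tf.convert_to_tensor(np.array([[0, 24, 24, 507, 0]]))
ConvertLabels(source_values=[0, 24, 507], dest_values=[0, 1, 2])(labels)
>> [[0, 1, 1, 2, 0]]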
|
|
1669 |
""" |
|
|
1670 |
|
|
|
1671 |
def __init__(self, source_values, dest_values=None, **kwargs): |
|
|
1672 |
self.source_values = source_values |
|
|
1673 |
self.dest_values = dest_values |
|
|
1674 |
self.lut = None |
|
|
1675 |
super(ConvertLabels, self).__init__(**kwargs) |
|
|
1676 |
|
|
|
1677 |
def get_config(self): |
|
|
1678 |
config = super().get_config() |
|
|
1679 |
config["source_values"] = self.source_values |
|
|
1680 |
config["dest_values"] = self.dest_values |
|
|
1681 |
return config |
|
|
1682 |
|
|
|
1683 |
def build(self, input_shape): |
|
|
1684 |
self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32') |
|
|
1685 |
self.built = True |
|
|
1686 |
super(ConvertLabels, self).build(input_shape) |
|
|
1687 |
|
|
|
1688 |
def call(self, inputs, **kwargs): |
|
|
1689 |
return tf.gather(self.lut, tf.cast(inputs, dtype='int32')) |
|
|
1690 |
|
|
|
1691 |
|
|
|
1692 |
class PadAroundCentre(Layer): |
|
|
1693 |
"""Pad the input tensor to the specified shape with the given value. |
|
|
1694 |
The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. |
|
|
1695 |
:param pad_margin: margin to use for padding. The tensor will be padded by the provided margin on each side. |
|
|
1696 |
Can either be a number (all axes padded with the same margin), or a list/numpy array of length n_dims. |
|
|
1697 |
example: if tensor is of shape [batch, x, y, z, n_channels] and margin=10, then the padded tensor will be of |
|
|
1698 |
shape [batch, x+2*10, y+2*10, z+2*10, n_channels]. |
|
|
1699 |
:param pad_shape: shape to pad the tensor to. Can either be a number (all axes padded to the same shape), or a |
|
|
1700 |
list/numpy array of length n_dims. |
|
|
1701 |
:param value: value to pad the tensors with. Default is 0. |
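example: with pad_shape=256 and an input of shape [batch, 160, 160, 160, n_channels], the output has
shape [batch, 256, 256, 256, n_channels], the original volume being kept at the centre of the padded one.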
|
|
1702 |
""" |
|
|
1703 |
|
|
|
1704 |
def __init__(self, pad_margin=None, pad_shape=None, value=0, **kwargs): |
|
|
1705 |
self.pad_margin = pad_margin |
|
|
1706 |
self.pad_shape = pad_shape |
|
|
1707 |
self.value = value |
|
|
1708 |
self.pad_margin_tens = None |
|
|
1709 |
self.pad_shape_tens = None |
|
|
1710 |
self.n_dims = None |
|
|
1711 |
super(PadAroundCentre, self).__init__(**kwargs) |
|
|
1712 |
|
|
|
1713 |
def get_config(self): |
|
|
1714 |
config = super().get_config() |
|
|
1715 |
config["pad_margin"] = self.pad_margin |
|
|
1716 |
config["pad_shape"] = self.pad_shape |
|
|
1717 |
config["value"] = self.value |
|
|
1718 |
return config |
|
|
1719 |
|
|
|
1720 |
def build(self, input_shape): |
|
|
1721 |
# input shape |
|
|
1722 |
self.n_dims = len(input_shape) - 2 |
|
|
1723 |
shape = list(input_shape) |
|
|
1724 |
shape[0] = 0 |
|
|
1725 |
shape[-1] = 0 |
|
|
1726 |
|
|
|
1727 |
if self.pad_margin is not None: |
|
|
1728 |
assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.' |
|
|
1729 |
|
|
|
1730 |
# reformat padding margins |
|
|
1731 |
pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2)) |
|
|
1732 |
self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32') |
|
|
1733 |
|
|
|
1734 |
elif self.pad_shape is not None: |
|
|
1735 |
assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.' |
|
|
1736 |
|
|
|
1737 |
# pad shape |
|
|
1738 |
tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32') |
|
|
1739 |
self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0]) |
|
|
1740 |
self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32') |
|
|
1741 |
self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) |
|
|
1742 |
|
|
|
1743 |
# padding margin |
|
|
1744 |
min_margins = (self.pad_shape_tens - tensor_shape) / 2 |
|
|
1745 |
max_margins = self.pad_shape_tens - tensor_shape - min_margins |
|
|
1746 |
self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) |
|
|
1747 |
|
|
|
1748 |
else: |
|
|
1749 |
raise Exception('please either provide a padding shape or a padding margin.') |
|
|
1750 |
|
|
|
1751 |
self.built = True |
|
|
1752 |
super(PadAroundCentre, self).build(input_shape) |
|
|
1753 |
|
|
|
1754 |
def call(self, inputs, **kwargs): |
|
|
1755 |
return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value) |
|
|
1756 |
|
|
|
1757 |
|
|
|
1758 |
class MaskEdges(Layer): |
|
|
1759 |
"""Reset the edges of a tensor to zero (i.e. with bands of zeros along the specified axes). |
|
|
1760 |
The width of the zero-band is randomly drawn from a uniform distribution, whose range is given in boundaries. |
|
|
1761 |
|
|
|
1762 |
:param axes: axes along which to reset edges to zero. Can be an int (single axis), or a sequence. |
|
|
1763 |
:param boundaries: numpy array of shape (len(axes), 4). Each row contains the two bounds of the uniform |
|
|
1764 |
distributions from which we draw the width of the zero-bands on each side. |
|
|
1765 |
Those bounds must be expressed as fractions of the tensor size along the corresponding axis (i.e. between 0 and 1). |
|
|
1766 |
:return: a tensor of the same shape as the input, with bands of zeros along the specified axes. |
|
|
1767 |
|
|
|
1768 |
example: |
|
|
1769 |
tensor=tf.constant([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1770 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1771 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1772 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1773 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1774 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1775 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1776 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1777 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1778 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]) # shape = [1,10,10,1] |
|
|
1779 |
axes=1 |
|
|
1780 |
boundaries = np.array([[0.2, 0.45, 0.85, 0.9]]) |
|
|
1781 |
|
|
|
1782 |
In this case, we reset the edges along the 2nd dimension (i.e. the 1st dimension after the batch dimension), |
|
|
1783 |
the 1st zero-band will expand from the 1st row to a number drawn from [0.2*tensor.shape[1], 0.45*tensor.shape[1]], |
|
|
1784 |
and the 2nd zero-band will expand from a row drawn from [0.85*tensor.shape[1], 0.9*tensor.shape[1]], to the end of |
|
|
1785 |
the tensor. A possible output could be: |
|
|
1786 |
array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], |
|
|
1787 |
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], |
|
|
1788 |
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], |
|
|
1789 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1790 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1791 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1792 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1793 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1794 |
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], |
|
|
1795 |
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]) # shape = [1,10,10,1] |
|
|
1796 |
""" |
|
|
1797 |
|
|
|
1798 |
def __init__(self, axes, boundaries, prob_mask=1, **kwargs): |
|
|
1799 |
self.axes = utils.reformat_to_list(axes, dtype='int') |
|
|
1800 |
self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes)) |
|
|
1801 |
self.prob_mask = prob_mask |
|
|
1802 |
self.inputshape = None |
|
|
1803 |
super(MaskEdges, self).__init__(**kwargs) |
|
|
1804 |
|
|
|
1805 |
def get_config(self): |
|
|
1806 |
config = super().get_config() |
|
|
1807 |
config["axes"] = self.axes |
|
|
1808 |
config["boundaries"] = self.boundaries |
|
|
1809 |
config["prob_mask"] = self.prob_mask |
|
|
1810 |
return config |
|
|
1811 |
|
|
|
1812 |
def build(self, input_shape): |
|
|
1813 |
self.inputshape = input_shape |
|
|
1814 |
self.built = True |
|
|
1815 |
super(MaskEdges, self).build(input_shape) |
|
|
1816 |
|
|
|
1817 |
def call(self, inputs, **kwargs): |
|
|
1818 |
|
|
|
1819 |
# build mask |
|
|
1820 |
mask = tf.ones_like(inputs) |
|
|
1821 |
for i, axis in enumerate(self.axes): |
|
|
1822 |
|
|
|
1823 |
# select restricting indices |
|
|
1824 |
axis_boundaries = self.boundaries[i, :] |
|
|
1825 |
idx1 = tf.math.round(tf.random.uniform([1], |
|
|
1826 |
minval=axis_boundaries[0] * self.inputshape[axis], |
|
|
1827 |
maxval=axis_boundaries[1] * self.inputshape[axis])) |
|
|
1828 |
idx2 = tf.math.round(tf.random.uniform([1], |
|
|
1829 |
minval=axis_boundaries[2] * self.inputshape[axis], |
|
|
1830 |
maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1) |
|
|
1831 |
idx3 = self.inputshape[axis] - idx1 - idx2 |
|
|
1832 |
split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32') |
|
|
1833 |
|
|
|
1834 |
# update mask |
|
|
1835 |
split_list = tf.split(inputs, split_idx, axis=axis) |
|
|
1836 |
tmp_mask = tf.concat([tf.zeros_like(split_list[0]), |
|
|
1837 |
tf.ones_like(split_list[1]), |
|
|
1838 |
tf.zeros_like(split_list[2])], axis=axis) |
|
|
1839 |
mask = mask * tmp_mask |
|
|
1840 |
|
|
|
1841 |
# mask second_channel |
|
|
1842 |
tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), |
|
|
1843 |
inputs * mask, |
|
|
1844 |
inputs) |
|
|
1845 |
|
|
|
1846 |
return [tensor, mask] |
|
|
1847 |
|
|
|
1848 |
def compute_output_shape(self, input_shape): |
|
|
1849 |
return [input_shape] * 2 |
|
|
1850 |
|
|
|
1851 |
|
|
|
1852 |
class ImageGradients(Layer): |
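"""This layer computes the spatial intensity gradients of the input tensor along each spatial direction.
The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].

:param gradient_type: method used to compute the gradients. Can be 'sobel' (separable Sobel kernels applied
to each channel) or '1-step_diff' (one-step finite differences, only supported for 2D and 3D inputs).
:param return_magnitude: whether to return the voxel-wise gradient magnitude instead of concatenating the
gradients along each direction into the channel axis. Default is False.
"""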
|
|
1853 |
|
|
|
1854 |
def __init__(self, gradient_type='sobel', return_magnitude=False, **kwargs): |
|
|
1855 |
|
|
|
1856 |
self.gradient_type = gradient_type |
|
|
1857 |
assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \ |
|
|
1858 |
'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type |
|
|
1859 |
|
|
|
1860 |
# shape |
|
|
1861 |
self.n_dims = 0 |
|
|
1862 |
self.shape = None |
|
|
1863 |
self.n_channels = 0 |
|
|
1864 |
|
|
|
1865 |
# convolution params if sobel diff |
|
|
1866 |
self.stride = None |
|
|
1867 |
self.kernels = None |
|
|
1868 |
self.convnd = None |
|
|
1869 |
|
|
|
1870 |
self.return_magnitude = return_magnitude |
|
|
1871 |
|
|
|
1872 |
super(ImageGradients, self).__init__(**kwargs) |
|
|
1873 |
|
|
|
1874 |
def get_config(self): |
|
|
1875 |
config = super().get_config() |
|
|
1876 |
config["gradient_type"] = self.gradient_type |
|
|
1877 |
config["return_magnitude"] = self.return_magnitude |
|
|
1878 |
return config |
|
|
1879 |
|
|
|
1880 |
def build(self, input_shape): |
|
|
1881 |
|
|
|
1882 |
# get shapes |
|
|
1883 |
self.n_dims = len(input_shape) - 2 |
|
|
1884 |
self.shape = input_shape[1:] |
|
|
1885 |
self.n_channels = input_shape[-1] |
|
|
1886 |
|
|
|
1887 |
# prepare kernel if sobel gradients |
|
|
1888 |
if self.gradient_type == 'sobel': |
|
|
1889 |
self.kernels = l2i_et.sobel_kernels(self.n_dims) |
|
|
1890 |
self.stride = [1] * (self.n_dims + 2) |
|
|
1891 |
self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) |
|
|
1892 |
else: |
|
|
1893 |
self.kernels = self.convnd = self.stride = None |
|
|
1894 |
|
|
|
1895 |
self.built = True |
|
|
1896 |
super(ImageGradients, self).build(input_shape) |
|
|
1897 |
|
|
|
1898 |
def call(self, inputs, **kwargs): |
|
|
1899 |
|
|
|
1900 |
image = inputs |
|
|
1901 |
batchsize = tf.split(tf.shape(inputs), [1, -1])[0] |
|
|
1902 |
gradients = list() |
|
|
1903 |
|
|
|
1904 |
# sobel method |
|
|
1905 |
if self.gradient_type == 'sobel': |
|
|
1906 |
# get sobel gradients in each direction |
|
|
1907 |
for n in range(self.n_dims): |
|
|
1908 |
gradient = image |
|
|
1909 |
# apply a 1D kernel in each direction (Sobel kernels are separable), instead of applying an nD kernel |
|
|
1910 |
for k in self.kernels[n]: |
|
|
1911 |
gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., n], -1), k, self.stride, 'SAME') |
|
|
1912 |
for n in range(self.n_channels)], -1) |
|
|
1913 |
gradients.append(gradient) |
|
|
1914 |
|
|
|
1915 |
# 1-step method, only supports 2 and 3D |
|
|
1916 |
else: |
|
|
1917 |
|
|
|
1918 |
# get 1-step diff |
|
|
1919 |
if self.n_dims == 2: |
|
|
1920 |
gradients.append(image[:, 1:, :, :] - image[:, :-1, :, :]) # dx |
|
|
1921 |
gradients.append(image[:, :, 1:, :] - image[:, :, :-1, :]) # dy |
|
|
1922 |
|
|
|
1923 |
elif self.n_dims == 3: |
|
|
1924 |
gradients.append(image[:, 1:, :, :, :] - image[:, :-1, :, :, :]) # dx |
|
|
1925 |
gradients.append(image[:, :, 1:, :, :] - image[:, :, :-1, :, :]) # dy |
|
|
1926 |
gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz |
|
|
1927 |
|
|
|
1928 |
else: |
|
|
1929 |
raise Exception('ImageGradients only supports 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims) |
|
|
1930 |
|
|
|
1931 |
# pad with zeros to return tensors of the same shape as input |
|
|
1932 |
for i in range(self.n_dims): |
|
|
1933 |
tmp_shape = list(self.shape) |
|
|
1934 |
tmp_shape[i] = 1 |
|
|
1935 |
zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype) |
|
|
1936 |
gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) |
|
|
1937 |
|
|
|
1938 |
# compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis |
|
|
1939 |
if self.return_magnitude: |
|
|
1940 |
gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1)) |
|
|
1941 |
else: |
|
|
1942 |
gradients = tf.concat(gradients, axis=-1) |
|
|
1943 |
|
|
|
1944 |
return gradients |
|
|
1945 |
|
|
|
1946 |
def compute_output_shape(self, input_shape): |
|
|
1947 |
if not self.return_magnitude: |
|
|
1948 |
input_shape = list(input_shape) |
|
|
1949 |
input_shape[-1] = self.n_dims |
|
|
1950 |
return tuple(input_shape) |
|
|
1951 |
|
|
|
1952 |
|
|
|
1953 |
class RandomDilationErosion(Layer): |
|
|
1954 |
""" |
|
|
1955 |
GPU implementation of binary dilation or erosion. The operation can be chosen to be always a dilation, or always an |
|
|
1956 |
erosion, or randomly chosen between the two for each element of the batch. |
|
|
1957 |
The chosen operation is applied to the input with a given probability. Moreover, it is also possible to randomise |
|
|
1958 |
the factor of the operation for each element of the mini-batch. |
|
|
1959 |
:param min_factor: minimum possible value for the dilation/erosion factor. Must be an integer. |
|
|
1960 |
:param max_factor: maximum possible value for the dilation/erosion factor. Must be an integer. |
|
|
1961 |
Set it to the same value as min_factor to always perform dilation/erosion with the same factor. |
|
|
1962 |
:param prob: probability with which to apply the selected operation to the input. |
|
|
1963 |
:param operation: which operation to apply. Can be 'dilation' or 'erosion' or 'random'. |
|
|
1964 |
:param return_mask: if operation is erosion and the input of this layer is a label map, we have the |
|
|
1965 |
choice to either return the eroded label map or the mask (return_mask=True) |
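example (illustrative sketch only; the shape and parameter values below are assumed):
mask = keras.layers.Input(shape=(160, 160, 160, 1))
eroded = RandomDilationErosion(min_factor=1, max_factor=5, operation='erosion', prob=1)(mask)
here each batch element is eroded with a structuring element whose size is drawn between min_factor
and max_factor.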
|
|
1966 |
""" |
|
|
1967 |
|
|
|
1968 |
def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False, |
|
|
1969 |
**kwargs): |
|
|
1970 |
|
|
|
1971 |
self.min_factor = min_factor |
|
|
1972 |
self.max_factor = max_factor |
|
|
1973 |
self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor |
|
|
1974 |
self.prob = prob |
|
|
1975 |
self.operation = operation |
|
|
1976 |
self.return_mask = return_mask |
|
|
1977 |
self.n_dims = None |
|
|
1978 |
self.inshape = None |
|
|
1979 |
self.n_channels = None |
|
|
1980 |
self.convnd = None |
|
|
1981 |
super(RandomDilationErosion, self).__init__(**kwargs) |
|
|
1982 |
|
|
|
1983 |
def get_config(self): |
|
|
1984 |
config = super().get_config() |
|
|
1985 |
config["min_factor"] = self.min_factor |
|
|
1986 |
config["max_factor"] = self.max_factor |
|
|
1987 |
config["max_factor_dilate"] = self.max_factor_dilate |
|
|
1988 |
config["prob"] = self.prob |
|
|
1989 |
config["operation"] = self.operation |
|
|
1990 |
config["return_mask"] = self.return_mask |
|
|
1991 |
return config |
|
|
1992 |
|
|
|
1993 |
def build(self, input_shape): |
|
|
1994 |
|
|
|
1995 |
# input shape |
|
|
1996 |
self.inshape = input_shape |
|
|
1997 |
self.n_dims = len(self.inshape) - 2 |
|
|
1998 |
self.n_channels = self.inshape[-1] |
|
|
1999 |
|
|
|
2000 |
# prepare convolution |
|
|
2001 |
self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) |
|
|
2002 |
|
|
|
2003 |
self.built = True |
|
|
2004 |
super(RandomDilationErosion, self).build(input_shape) |
|
|
2005 |
|
|
|
2006 |
def call(self, inputs, **kwargs): |
|
|
2007 |
|
|
|
2008 |
# sample the value that decides whether to apply the operation. When operation is 'random', a negative value means erosion and a positive value means dilation |
|
|
2009 |
batchsize = tf.split(tf.shape(inputs), [1, -1])[0] |
|
|
2010 |
shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0) |
|
|
2011 |
if self.operation == 'dilation': |
|
|
2012 |
prob = tf.random.uniform(shape, 0, 1) |
|
|
2013 |
elif self.operation == 'erosion': |
|
|
2014 |
prob = tf.random.uniform(shape, -1, 0) |
|
|
2015 |
elif self.operation == 'random': |
|
|
2016 |
prob = tf.random.uniform(shape, -1, 1) |
|
|
2017 |
else: |
|
|
2018 |
raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation) |
|
|
2019 |
|
|
|
2020 |
# build kernel |
|
|
2021 |
if self.min_factor == self.max_factor: |
|
|
2022 |
dist_threshold = self.min_factor * tf.ones(shape, dtype='int32') |
|
|
2023 |
else: |
|
|
2024 |
if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'): |
|
|
2025 |
dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32') |
|
|
2026 |
else: |
|
|
2027 |
dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32') |
|
|
2028 |
kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor) |
|
|
2029 |
|
|
|
2030 |
# convolve input mask with kernel according to given probability |
|
|
2031 |
mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32') |
|
|
2032 |
mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) |
|
|
2033 |
mask = tf.cast(mask, 'bool') |
|
|
2034 |
|
|
|
2035 |
if self.return_mask: |
|
|
2036 |
return mask |
|
|
2037 |
else: |
|
|
2038 |
return inputs * tf.cast(mask, dtype=inputs.dtype) |
|
|
2039 |
|
|
|
2040 |
def _sample_factor(self, inputs): |
|
|
2041 |
return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0), |
|
|
2042 |
tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'), |
|
|
2043 |
tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')), |
|
|
2044 |
dtype='float32') |
|
|
2045 |
|
|
|
2046 |
def _single_blur(self, inputs): |
|
|
2047 |
# dilate... |
|
|
2048 |
new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), |
|
|
2049 |
tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], |
|
|
2050 |
[1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), |
|
|
2051 |
inputs[0]) |
|
|
2052 |
# ...or erode |
|
|
2053 |
new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)), |
|
|
2054 |
1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1], |
|
|
2055 |
[1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), |
|
|
2056 |
new_mask) |
|
|
2057 |
return new_mask |
|
|
2058 |
|
|
|
2059 |
def compute_output_shape(self, input_shape): |
|
|
2060 |
return input_shape |