Diff of /ext/neuron/layers.py [000000] .. [e571d1]

Switch to unified view

a b/ext/neuron/layers.py
1
"""
2
tensorflow/keras utilities for the neuron project
3
4
If you use this code, please cite 
5
Dalca AV, Guttag J, Sabuncu MR
6
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
7
CVPR 2018
8
9
or for the transformation/integration functions:
10
11
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
12
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
13
MICCAI 2018.
14
15
Contact: adalca [at] csail [dot] mit [dot] edu
16
License: GPLv3
17
"""
18
19
# third party
20
import tensorflow as tf
21
from keras import backend as K
22
from keras.layers import Layer
23
from copy import deepcopy
24
25
# local
26
from ext.neuron.utils import transform, resize, integrate_vec, affine_to_shift, combine_non_linear_and_aff_to_shift
27
28
29
class SpatialTransformer(Layer):
30
    """
31
    N-D Spatial Transformer Tensorflow / Keras Layer
32
33
    The Layer can handle both affine and dense transforms. 
34
    Both transforms are meant to give a 'shift' from the current position.
35
    Therefore, a dense transform gives displacements (not absolute locations) at each voxel,
36
    and an affine transform gives the *difference* of the affine matrix from 
37
    the identity matrix.
38
39
    If you find this function useful, please cite:
40
      Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
41
      Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
42
      MICCAI 2018.
43
44
    Originally, this code was based on voxelmorph code, which 
45
    was in turn transformed to be dense with the help of (affine) STN code 
46
    via https://github.com/kevinzakka/spatial-transformer-network
47
48
    Since then, we've re-written the code to be generalized to any 
49
    dimensions, and along the way wrote grid and interpolation functions
50
    """
51
52
    def __init__(self,
53
                 interp_method='linear',
54
                 indexing='ij',
55
                 single_transform=False,
56
                 **kwargs):
57
        """
58
        Parameters: 
59
            interp_method: 'linear' or 'nearest'
60
            single_transform: whether a single transform supplied for the whole batch
61
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
62
                'xy' indexing will have the first two entries of the flow 
63
                (along last axis) flipped compared to 'ij' indexing
64
        """
65
        self.interp_method = interp_method
66
        self.ndims = None
67
        self.inshape = None
68
        self.single_transform = single_transform
69
        self.is_affine = list()
70
71
        assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
72
        self.indexing = indexing
73
74
        super(self.__class__, self).__init__(**kwargs)
75
76
    def get_config(self):
77
        config = super().get_config()
78
        config["interp_method"] = self.interp_method
79
        config["indexing"] = self.indexing
80
        config["single_transform"] = self.single_transform
81
        return config
82
83
    def build(self, input_shape):
84
        """
85
        input_shape should be a list for two inputs:
86
        input1: image.
87
        input2: list of transform Tensors
88
            if affine:
89
                should be an N+1 x N+1 matrix
90
                *or* a N+1*N+1 tensor (which will be reshaped to N x (N+1) and an identity row added)
91
            if not affine:
92
                should be a *vol_shape x N
93
        """
94
95
        if len(input_shape) > 3:
96
            raise Exception('Spatial Transformer must be called on a list of min length 2 and max length 3.'
97
                            'First argument is the image followed by the affine and non linear transforms.')
98
99
        # set up number of dimensions
100
        self.ndims = len(input_shape[0]) - 2
101
        self.inshape = input_shape
102
        trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]]
103
104
        for (i, shape) in enumerate(trf_shape):
105
106
            # the transform is an affine iff:
107
            # it's a 1D Tensor [dense transforms need to be at least ndims + 1]
108
            # it's a 2D Tensor and shape == [N+1, N+1].
109
            self.is_affine.append(len(shape) == 1 or
110
                                  (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape])))
111
112
            # check sizes
113
            if self.is_affine[i] and len(shape) == 1:
114
                ex = self.ndims * (self.ndims + 1)
115
                if shape[0] != ex:
116
                    raise Exception('Expected flattened affine of len %d but got %d' % (ex, shape[0]))
117
118
            if not self.is_affine[i]:
119
                if shape[-1] != self.ndims:
120
                    raise Exception('Offset flow field size expected: %d, found: %d' % (self.ndims, shape[-1]))
121
122
        # confirm built
123
        self.built = True
124
125
    def call(self, inputs, **kwargs):
126
        """
127
        Parameters
128
            inputs: list with several entries: the volume followed by the transforms
129
        """
130
131
        # check shapes
132
        assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len(inputs)
133
        vol = inputs[0]
134
        trf = inputs[1:]
135
136
        # necessary for multi_gpu models...
137
        vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
138
        for i in range(len(trf)):
139
            trf[i] = K.reshape(trf[i], [-1, *self.inshape[i+1][1:]])
140
141
        # reorder transforms, non-linear first and affine second
142
        ind_nonlinear_linear = [i[0] for i in sorted(enumerate(self.is_affine), key=lambda x:x[1])]
143
        self.is_affine = [self.is_affine[i] for i in ind_nonlinear_linear]
144
        self.inshape = [self.inshape[i] for i in ind_nonlinear_linear]
145
        trf = [trf[i] for i in ind_nonlinear_linear]
146
147
        # go from affine to deformation field
148
        if len(trf) == 1:
149
            trf = trf[0]
150
            if self.is_affine[0]:
151
                trf = tf.map_fn(lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32)
152
        # combine non-linear and affine to obtain a single deformation field
153
        elif len(trf) == 2:
154
            trf = tf.map_fn(lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32)
155
156
        # prepare location shift
157
        if self.indexing == 'xy':  # shift the first two dimensions
158
            trf_split = tf.split(trf, trf.shape[-1], axis=-1)
159
            trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
160
            trf = tf.concat(trf_lst, -1)
161
162
        # map transform across batch
163
        if self.single_transform:
164
            return tf.map_fn(self._single_transform, [vol, trf[0, :]], dtype=tf.float32)
165
        else:
166
            return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)
167
168
    def _single_aff_to_shift(self, trf, volshape):
169
        if len(trf.shape) == 1:  # go from vector to matrix
170
            trf = tf.reshape(trf, [self.ndims, self.ndims + 1])
171
        return affine_to_shift(trf, volshape, shift_center=True)
172
173
    def _non_linear_and_aff_to_shift(self, trf, volshape):
174
        if len(trf[1].shape) == 1:  # go from vector to matrix
175
            trf[1] = tf.reshape(trf[1], [self.ndims, self.ndims + 1])
176
        return combine_non_linear_and_aff_to_shift(trf, volshape, shift_center=True)
177
178
    def _single_transform(self, inputs):
179
        return transform(inputs[0], inputs[1], interp_method=self.interp_method)
180
181
182
class VecInt(Layer):
183
    """
184
    Vector Integration Layer
185
186
    Enables vector integration via several methods 
187
    (ode or quadrature for time-dependent vector fields, 
188
    scaling and squaring for stationary fields)
189
190
    If you find this function useful, please cite:
191
      Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
192
      Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
193
      MICCAI 2018.
194
    """
195
196
    def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1,
197
                 ode_args=None,
198
                 odeint_fn=None, **kwargs):
199
        """        
200
        Parameters:
201
            method can be any of the methods in neuron.utils.integrate_vec
202
            indexing can be 'xy' (switches first two dimensions) or 'ij'
203
            int_steps is the number of integration steps
204
            out_time_pt is time point at which to output if using odeint integration
205
        """
206
207
        assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
208
        self.indexing = indexing
209
        self.method = method
210
        self.int_steps = int_steps
211
        self.inshape = None
212
        self.out_time_pt = out_time_pt
213
        self.odeint_fn = odeint_fn  # if none then will use a tensorflow function
214
        self.ode_args = ode_args
215
        if ode_args is None:
216
            self.ode_args = {'rtol': 1e-6, 'atol': 1e-12}
217
        super(self.__class__, self).__init__(**kwargs)
218
219
    def get_config(self):
220
        config = super().get_config()
221
        config["indexing"] = self.indexing
222
        config["method"] = self.method
223
        config["int_steps"] = self.int_steps
224
        config["out_time_pt"] = self.out_time_pt
225
        config["ode_args"] = self.ode_args
226
        config["odeint_fn"] = self.odeint_fn
227
        return config
228
229
    def build(self, input_shape):
230
        # confirm built
231
        self.built = True
232
233
        trf_shape = input_shape
234
        if isinstance(input_shape[0], (list, tuple)):
235
            trf_shape = input_shape[0]
236
        self.inshape = trf_shape
237
238
        if trf_shape[-1] != len(trf_shape) - 2:
239
            raise Exception('transform ndims %d does not match expected ndims %d' % (trf_shape[-1], len(trf_shape) - 2))
240
241
    def call(self, inputs, **kwargs):
242
        if not isinstance(inputs, (list, tuple)):
243
            inputs = [inputs]
244
        loc_shift = inputs[0]
245
246
        # necessary for multi_gpu models...
247
        loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]])
248
249
        # prepare location shift
250
        if self.indexing == 'xy':  # shift the first two dimensions
251
            loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1)
252
            loc_shift_lst = [loc_shift_split[1], loc_shift_split[0], *loc_shift_split[2:]]
253
            loc_shift = tf.concat(loc_shift_lst, -1)
254
255
        if len(inputs) > 1:
256
            assert self.out_time_pt is None, 'out_time_pt should be None if providing batch_based out_time_pt'
257
258
        # map transform across batch
259
        out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32)
260
        return out
261
262
    def _single_int(self, inputs):
263
264
        vel = inputs[0]
265
        out_time_pt = self.out_time_pt
266
        if len(inputs) == 2:
267
            out_time_pt = inputs[1]
268
        return integrate_vec(vel, method=self.method,
269
                             nb_steps=self.int_steps,
270
                             ode_args=self.ode_args,
271
                             out_time_pt=out_time_pt,
272
                             odeint_fn=self.odeint_fn)
273
274
275
class Resize(Layer):
276
    """
277
    N-D Resize Tensorflow / Keras Layer
278
    Note: this is not re-shaping an existing volume, but resizing, like scipy's "Zoom"
279
280
    If you find this function useful, please cite:
281
    Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,Dalca AV, Guttag J, Sabuncu MR
282
    CVPR 2018
283
284
    Since then, we've re-written the code to be generalized to any 
285
    dimensions, and along the way wrote grid and interpolation functions
286
    """
287
288
    def __init__(self,
289
                 zoom_factor=None,
290
                 size=None,
291
                 interp_method='linear',
292
                 **kwargs):
293
        """
294
        Parameters: 
295
            interp_method: 'linear' or 'nearest'
296
                'xy' indexing will have the first two entries of the flow 
297
                (along last axis) flipped compared to 'ij' indexing
298
        """
299
        self.zoom_factor = zoom_factor
300
        self.size = list(size)
301
        self.zoom_factor0 = None
302
        self.size0 = None
303
        self.interp_method = interp_method
304
        self.ndims = None
305
        self.inshape = None
306
        super(Resize, self).__init__(**kwargs)
307
308
    def get_config(self):
309
        config = super().get_config()
310
        config["zoom_factor"] = self.zoom_factor
311
        config["size"] = self.size
312
        config["interp_method"] = self.interp_method
313
        return config
314
315
    def build(self, input_shape):
316
        """
317
        input_shape should be an element of list of one inputs:
318
        input1: volume
319
                should be a *vol_shape x N
320
        """
321
322
        if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1:
323
            raise Exception('Resize must be called on a list of length 1.')
324
325
        if isinstance(input_shape[0], (list, tuple)):
326
            input_shape = input_shape[0]
327
328
        # set up number of dimensions
329
        self.ndims = len(input_shape) - 2
330
        self.inshape = input_shape
331
332
        # check zoom_factor
333
        if isinstance(self.zoom_factor, float):
334
            self.zoom_factor0 = [self.zoom_factor] * self.ndims
335
        elif self.zoom_factor is None:
336
            self.zoom_factor0 = [0] * self.ndims
337
        elif isinstance(self.zoom_factor, (list, tuple)):
338
            self.zoom_factor0 = deepcopy(self.zoom_factor)
339
            assert len(self.zoom_factor0) == self.ndims, \
340
                'zoom factor length {} does not match number of dimensions {}'.format(len(self.zoom_factor), self.ndims)
341
        else:
342
            raise Exception('zoom_factor should be an int or a list/tuple of int (or None if size is not set to None)')
343
344
        # check size
345
        if isinstance(self.size, int):
346
            self.size0 = [self.size] * self.ndims
347
        elif self.size is None:
348
            self.size0 = [0] * self.ndims
349
        elif isinstance(self.size, (list, tuple)):
350
            self.size0 = deepcopy(self.size)
351
            assert len(self.size0) == self.ndims, \
352
                'size length {} does not match number of dimensions {}'.format(len(self.size0), self.ndims)
353
        else:
354
            raise Exception('size should be an int or a list/tuple of int (or None if zoom_factor is not set to None)')
355
356
        # confirm built
357
        self.built = True
358
359
        super(Resize, self).build(input_shape)  # Be sure to call this somewhere!
360
361
    def call(self, inputs, **kwargs):
362
        """
363
        Parameters
364
            inputs: volume or list of one volume
365
        """
366
367
        # check shapes
368
        if isinstance(inputs, (list, tuple)):
369
            assert len(inputs) == 1, "inputs has to be len 1. found: %d" % len(inputs)
370
            vol = inputs[0]
371
        else:
372
            vol = inputs
373
374
        # necessary for multi_gpu models...
375
        vol = K.reshape(vol, [-1, *self.inshape[1:]])
376
377
        # set value of missing size or zoom_factor
378
        if not any(self.zoom_factor0):
379
            self.zoom_factor0 = [self.size0[i] / self.inshape[i+1] for i in range(self.ndims)]
380
        else:
381
            self.size0 = [int(self.inshape[f+1] * self.zoom_factor0[f]) for f in range(self.ndims)]
382
383
        # map transform across batch
384
        return tf.map_fn(self._single_resize, vol, dtype=vol.dtype)
385
386
    def compute_output_shape(self, input_shape):
387
388
        output_shape = [input_shape[0]]
389
        output_shape += [int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims)]
390
        output_shape += [input_shape[-1]]
391
        return tuple(output_shape)
392
393
    def _single_resize(self, inputs):
394
        return resize(inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method)
395
396
397
# Zoom naming of resize, to match scipy's naming
398
Zoom = Resize
399
400
401
#########################################################
402
# "Local" layers -- layers with parameters at each voxel
403
#########################################################
404
405
class LocalBias(Layer):
406
    """ 
407
    Local bias layer: each pixel/voxel has its own bias operation (one parameter)
408
    out[v] = in[v] + b
409
    """
410
411
    def __init__(self, my_initializer='RandomNormal', biasmult=1.0, **kwargs):
412
        self.initializer = my_initializer
413
        self.biasmult = biasmult
414
        self.kernel = None
415
        super(LocalBias, self).__init__(**kwargs)
416
417
    def get_config(self):
418
        config = super().get_config()
419
        config["my_initializer"] = self.initializer
420
        config["biasmult"] = self.biasmult
421
        return config
422
423
    def build(self, input_shape):
424
        # Create a trainable weight variable for this layer.
425
        self.kernel = self.add_weight(name='kernel',
426
                                      shape=input_shape[1:],
427
                                      initializer=self.initializer,
428
                                      trainable=True)
429
        super(LocalBias, self).build(input_shape)  # Be sure to call this somewhere!
430
431
    def call(self, x, **kwargs):
432
        return x + self.kernel * self.biasmult  # weights are difference from input
433
434
    def compute_output_shape(self, input_shape):
435
        return input_shape