Diff of /ext/neuron/utils.py [000000] .. [e571d1]

Switch to unified view

a b/ext/neuron/utils.py
1
"""
2
tensorflow/keras utilities for the neuron project
3
4
If you use this code, please cite 
5
Dalca AV, Guttag J, Sabuncu MR
6
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
7
CVPR 2018
8
9
or for the transformation/interpolation related functions:
10
11
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
12
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
13
MICCAI 2018.
14
15
Contact: adalca [at] csail [dot] mit [dot] edu
16
License: GPLv3
17
"""
18
19
import itertools
20
import numpy as np
21
import tensorflow as tf
22
import keras.backend as K
23
24
25
def interpn(vol, loc, interp_method='linear'):
26
    """
27
    N-D gridded interpolation in tensorflow
28
29
    vol can have more dimensions than loc[i], in which case loc[i] acts as a slice 
30
    for the first dimensions
31
32
    Parameters:
33
        vol: volume with size vol_shape or [*vol_shape, nb_features]
34
        loc: an N-long list of N-D Tensors (the interpolation locations) for the new grid
35
            each tensor has to have the same size (but not nec. same size as vol)
36
            or a tensor of size [*new_vol_shape, D]
37
        interp_method: interpolation type 'linear' (default) or 'nearest'
38
39
    Returns:
40
        new interpolated volume of the same size as the entries in loc
41
    """
42
43
    if isinstance(loc, (list, tuple)):
44
        loc = tf.stack(loc, -1)
45
    nb_dims = loc.shape[-1]
46
47
    if len(vol.shape) not in [nb_dims, nb_dims + 1]:
48
        raise Exception("Number of loc Tensors %d does not match volume dimension %d"
49
                        % (nb_dims, len(vol.shape[:-1])))
50
51
    if nb_dims > len(vol.shape):
52
        raise Exception("Loc dimension %d does not match volume dimension %d"
53
                        % (nb_dims, len(vol.shape)))
54
55
    if len(vol.shape) == nb_dims:
56
        vol = K.expand_dims(vol, -1)
57
58
    # flatten and float location Tensors
59
    loc = tf.cast(loc, 'float32')
60
61
    if isinstance(vol.shape, tf.TensorShape):
62
        volshape = vol.shape.as_list()
63
    else:
64
        volshape = vol.shape
65
66
    # interpolate
67
    if interp_method == 'linear':
68
        loc0 = tf.floor(loc)
69
70
        # clip values
71
        max_loc = [d - 1 for d in vol.get_shape().as_list()]
72
        clipped_loc = [tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims)]
73
        loc0lst = [tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims)]
74
75
        # get other end of point cube
76
        loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)]
77
        locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]]
78
79
        # compute the difference between the upper value and the original value
80
        # differences are basically 1 - (pt - floor(pt))
81
        #   because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt))
82
        diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)]
83
        diff_loc0 = [1 - d for d in diff_loc1]
84
        weights_loc = [diff_loc1, diff_loc0]  # note reverse ordering since weights are inverse of diff.
85
86
        # go through all the cube corners, indexed by a ND binary vector 
87
        # e.g. [0, 0] means this "first" corner in a 2-D "cube"
88
        cube_pts = list(itertools.product([0, 1], repeat=nb_dims))
89
        interp_vol = 0
90
91
        for c in cube_pts:
92
            # get nd values
93
            # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091
94
            #   It works on GPU because we do not perform index validation checking on GPU -- it's too
95
            #   expensive. Instead we fill the output with zero for the corresponding value. The CPU
96
            #   version caught the bad index and returned the appropriate error.
97
            subs = [locs[c[d]][d] for d in range(nb_dims)]
98
99
            idx = sub2ind(vol.shape[:-1], subs)
100
            vol_val = tf.gather(tf.reshape(vol, [-1, volshape[-1]]), idx)
101
102
            # get the weight of this cube_pt based on the distance
103
            # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1
104
            # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0
105
            wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)]
106
            wt = prod_n(wts_lst)
107
            wt = K.expand_dims(wt, -1)
108
109
            # compute final weighted value for each cube corner
110
            interp_vol += wt * vol_val
111
112
    else:
113
        assert interp_method == 'nearest'
114
        roundloc = tf.cast(tf.round(loc), 'int32')
115
116
        # clip values
117
        max_loc = [tf.cast(d - 1, 'int32') for d in vol.shape]
118
        roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)]
119
120
        # get values
121
        idx = sub2ind(vol.shape[:-1], roundloc)
122
        interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx)
123
124
    return interp_vol
125
126
127
def resize(vol, zoom_factor, new_shape, interp_method='linear'):
128
    """
129
    if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims or ndims + 1
130
131
    if zoom_factor is an integer, then vol must be of length ndims + 1
132
133
    new_shape should be a list of length ndims
134
135
    """
136
137
    if isinstance(zoom_factor, (list, tuple)):
138
        ndims = len(zoom_factor)
139
        vol_shape = vol.shape[:ndims]
140
        assert len(vol_shape) in (ndims, ndims + 1), \
141
            "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims)
142
    else:
143
        vol_shape = vol.shape[:-1]
144
        ndims = len(vol_shape)
145
        zoom_factor = [zoom_factor] * ndims
146
147
    # get grid for new shape
148
    grid = volshape_to_ndgrid(new_shape)
149
    grid = [tf.cast(f, 'float32') for f in grid]
150
    offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)]
151
    offset = tf.stack(offset, ndims)
152
153
    # transform
154
    return transform(vol, offset, interp_method)
155
156
157
zoom = resize
158
159
160
def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'):
161
    """
162
    transform an affine matrix to a dense location shift tensor in tensorflow
163
164
    Algorithm:
165
        - get grid and shift grid to be centered at the center of the image (optionally)
166
        - apply affine matrix to each index.
167
        - subtract grid
168
169
    Parameters:
170
        affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor)
171
        volshape: 1xN Nd Tensor of the size of the volume.
172
        shift_center (optional)
173
        indexing
174
175
    Returns:
176
        shift field (Tensor) of size *volshape x N
177
    """
178
179
    if isinstance(volshape, tf.TensorShape):
180
        volshape = volshape.as_list()
181
182
    if affine_matrix.dtype != 'float32':
183
        affine_matrix = tf.cast(affine_matrix, 'float32')
184
185
    nb_dims = len(volshape)
186
187
    if len(affine_matrix.shape) == 1:
188
        if len(affine_matrix) != (nb_dims * (nb_dims + 1)):
189
            raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).'
190
                             'Got len %d' % len(affine_matrix))
191
192
        affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1])
193
194
    if not (affine_matrix.shape[0] in [nb_dims, nb_dims + 1] and affine_matrix.shape[1] == (nb_dims + 1)):
195
        raise Exception('Affine matrix shape should match'
196
                        '%d+1 x %d+1 or ' % (nb_dims, nb_dims) +
197
                        '%d x %d+1.' % (nb_dims, nb_dims) +
198
                        'Got: ' + str(volshape))
199
200
    # list of volume ndgrid
201
    # N-long list, each entry of shape volshape
202
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)
203
    mesh = [tf.cast(f, 'float32') for f in mesh]
204
205
    if shift_center:
206
        mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))]
207
208
    # add an all-ones entry and transform into a large matrix
209
    flat_mesh = [flatten(f) for f in mesh]
210
    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
211
    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # 4 x nb_voxels
212
213
    # compute locations
214
    loc_matrix = tf.matmul(affine_matrix, mesh_matrix)  # N+1 x nb_voxels
215
    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
216
    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N
217
218
    # get shifts and return
219
    return loc - tf.stack(mesh, axis=nb_dims)
220
221
222
def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=True, indexing='ij'):
223
    """
224
    transform an affine matrix to a dense location shift tensor in tensorflow
225
226
    Algorithm:
227
        - get grid and shift grid to be centered at the center of the image (optionally)
228
        - apply affine matrix to each index.
229
        - subtract grid
230
231
    Parameters:
232
        transform_list: list of non-linear tensor (size of volshape) and affine ND+1 x ND+1 or ND x ND+1 tensor
233
        volshape: 1xN Nd Tensor of the size of the volume.
234
        shift_center (optional)
235
        indexing
236
237
    Returns:
238
        shift field (Tensor) of size *volshape x N
239
    """
240
241
    if isinstance(volshape, tf.TensorShape):
242
        volshape = volshape.as_list()
243
244
    # convert transforms to floats
245
    for i in range(len(transform_list)):
246
        if transform_list[i].dtype != 'float32':
247
            transform_list[i] = tf.cast(transform_list[i], 'float32')
248
249
    nb_dims = len(volshape)
250
251
    # transform affine to matrix if given as vector
252
    if len(transform_list[1].shape) == 1:
253
        if len(transform_list[1]) != (nb_dims * (nb_dims + 1)):
254
            raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).'
255
                             'Got len %d' % len(transform_list[1]))
256
257
        transform_list[1] = tf.reshape(transform_list[1], [nb_dims, nb_dims + 1])
258
259
    if not (transform_list[1].shape[0] in [nb_dims, nb_dims + 1] and transform_list[1].shape[1] == (nb_dims + 1)):
260
        raise Exception('Affine matrix shape should match'
261
                        '%d+1 x %d+1 or ' % (nb_dims, nb_dims) +
262
                        '%d x %d+1.' % (nb_dims, nb_dims) +
263
                        'Got: ' + str(volshape))
264
265
    # list of volume ndgrid
266
    # N-long list, each entry of shape volshape
267
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)
268
    mesh = [tf.cast(f, 'float32') for f in mesh]
269
270
    if shift_center:
271
        mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))]
272
273
    # add an all-ones entry and transform into a large matrix
274
    # non_linear_mesh = tf.unstack(transform_list[0], axis=3)
275
    non_linear_mesh = tf.unstack(transform_list[0], axis=-1)
276
    flat_mesh = [flatten(mesh[i]+non_linear_mesh[i]) for i in range(len(mesh))]
277
    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
278
    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # N+1 x nb_voxels
279
280
    # compute locations
281
    loc_matrix = tf.matmul(transform_list[1], mesh_matrix)  # N+1 x nb_voxels
282
    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
283
    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N
284
285
    # get shifts and return
286
    return loc - tf.stack(mesh, axis=nb_dims)
287
288
289
def transform(vol, loc_shift, interp_method='linear', indexing='ij'):
290
    """
291
    transform interpolation N-D volumes (features) given shifts at each location in tensorflow
292
293
    Essentially interpolates volume vol at locations determined by loc_shift. 
294
    This is a spatial transform in the sense that at location [x] we now have the data from, 
295
    [x + shift] so we've moved data.
296
297
    Parameters:
298
        vol: volume with size vol_shape or [*vol_shape, nb_features]
299
        loc_shift: shift volume [*new_vol_shape, N]
300
        interp_method (default:'linear'): 'linear', 'nearest'
301
        indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
302
            In general, prefer to leave this 'ij'
303
    
304
    Return:
305
        new interpolated volumes in the same size as loc_shift[0]
306
    """
307
308
    # parse shapes
309
    if isinstance(loc_shift.shape, tf.TensorShape):
310
        volshape = loc_shift.shape[:-1].as_list()
311
    else:
312
        volshape = loc_shift.shape[:-1]
313
    nb_dims = len(volshape)
314
315
    # location should be meshed and delta
316
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)  # volume mesh
317
    loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
318
319
    # test single
320
    return interpn(vol, loc, interp_method=interp_method)
321
322
323
def integrate_vec(vec, time_dep=False, method='ss', **kwargs):
324
    """
325
    Integrate (stationary of time-dependent) vector field (N-D Tensor) in tensorflow
326
    
327
    Aside from directly using tensorflow's numerical integration odeint(), also implements 
328
    "scaling and squaring", and quadrature. Note that the diff. equation given to odeint
329
    is the one used in quadrature.   
330
331
    Parameters:
332
        vec: the Tensor field to integrate. 
333
            If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size),
334
            then vector shape (vec_shape) should be 
335
            [vol_size, vol_ndim] (if stationary)
336
            [vol_size, vol_ndim, nb_time_steps] (if time dependent)
337
        time_dep: bool whether vector is time dependent
338
        method: 'scaling_and_squaring' or 'ss' or 'quadrature'
339
        
340
        if using 'scaling_and_squaring': currently only supports integrating to time point 1.
341
            nb_steps int number of steps. Note that this means the vec field gets broken own to 2**nb_steps.
342
            so nb_steps of 0 means integral = vec.
343
344
    Returns:
345
        int_vec: integral of vector field with same shape as the input
346
    """
347
348
    if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']:
349
        raise ValueError("method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method)
350
351
    if method in ['ss', 'scaling_and_squaring']:
352
        nb_steps = kwargs['nb_steps']
353
        assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps
354
355
        if time_dep:
356
            svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)])
357
            assert 2 ** nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match"
358
359
            svec = svec / (2 ** nb_steps)
360
            for _ in range(nb_steps):
361
                svec = svec[0::2] + tf.map_fn(transform, svec[1::2, :], svec[0::2, :])
362
363
            disp = svec[0, :]
364
365
        else:
366
            vec = vec / (2 ** nb_steps)
367
            for _ in range(nb_steps):
368
                vec += transform(vec, vec)
369
            disp = vec
370
371
    else:  # method == 'quadrature':
372
        nb_steps = kwargs['nb_steps']
373
        assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps
374
375
        vec = vec / nb_steps
376
377
        if time_dep:
378
            disp = vec[..., 0]
379
            for si in range(nb_steps - 1):
380
                disp += transform(vec[..., si + 1], disp)
381
        else:
382
            disp = vec
383
            for _ in range(nb_steps - 1):
384
                disp += transform(vec, disp)
385
386
    return disp
387
388
389
def volshape_to_ndgrid(volshape, **kwargs):
390
    """
391
    compute Tensor ndgrid from a volume size
392
393
    Parameters:
394
        volshape: the volume size
395
396
    Returns:
397
        A list of Tensors
398
399
    See Also:
400
        ndgrid
401
    """
402
403
    isint = [float(d).is_integer() for d in volshape]
404
    if not all(isint):
405
        raise ValueError("volshape needs to be a list of integers")
406
407
    linvec = [tf.range(0, d) for d in volshape]
408
    return ndgrid(*linvec, **kwargs)
409
410
411
def volshape_to_meshgrid(volshape, **kwargs):
412
    """
413
    compute Tensor meshgrid from a volume size
414
415
    Parameters:
416
        volshape: the volume size
417
418
    Returns:
419
        A list of Tensors
420
421
    See Also:
422
        tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid
423
    """
424
425
    isint = [float(d).is_integer() for d in volshape]
426
    if not all(isint):
427
        raise ValueError("volshape needs to be a list of integers")
428
429
    linvec = [tf.range(0, d) for d in volshape]
430
    return meshgrid(*linvec, **kwargs)
431
432
433
def ndgrid(*args, **kwargs):
434
    """
435
    broadcast Tensors on an N-D grid with ij indexing
436
    uses meshgrid with ij indexing
437
438
    Parameters:
439
        *args: Tensors with rank 1
440
        **args: "name" (optional)
441
442
    Returns:
443
        A list of Tensors
444
    
445
    """
446
    return meshgrid(*args, indexing='ij', **kwargs)
447
448
449
def meshgrid(*args, **kwargs):
450
    """
451
    
452
    meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically
453
    improves runtime by changing the last step to tiling instead of multiplication.
454
    https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921
455
    
456
    Broadcasts parameters for evaluation on an N-D grid.
457
    Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
458
    of N-D coordinate arrays for evaluating expressions on an N-D grid.
459
    Notes:
460
    `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
461
    When the `indexing` argument is set to 'xy' (the default), the broadcasting
462
    instructions for the first two dimensions are swapped.
463
    Examples:
464
    Calling `X, Y = meshgrid(x, y)` with the tensors
465
    ```python
466
    x = [1, 2, 3]
467
    y = [4, 5, 6]
468
    X, Y = meshgrid(x, y)
469
    # X = [[1, 2, 3],
470
    #      [1, 2, 3],
471
    #      [1, 2, 3]]
472
    # Y = [[4, 4, 4],
473
    #      [5, 5, 5],
474
    #      [6, 6, 6]]
475
    ```
476
    Args:
477
    *args: `Tensor`s with rank 1.
478
    **kwargs:
479
      - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
480
      - name: A name for the operation (optional).
481
    Returns:
482
    outputs: A list of N `Tensor`s with rank N.
483
    Raises:
484
    TypeError: When no keyword arguments (kwargs) are passed.
485
    ValueError: When indexing keyword argument is not one of `xy` or `ij`.
486
    """
487
488
    indexing = kwargs.pop("indexing", "xy")
489
    if kwargs:
490
        key = list(kwargs.keys())[0]
491
        raise TypeError("'{}' is an invalid keyword argument "
492
                        "for this function".format(key))
493
494
    if indexing not in ("xy", "ij"):
495
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")
496
497
    # with ops.name_scope(name, "meshgrid", args) as name:
498
    ndim = len(args)
499
    s0 = (1,) * ndim
500
501
    # Prepare reshape by inserting dimensions with size 1 where needed
502
    output = []
503
    for i, x in enumerate(args):
504
        output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
505
    # Create parameters for broadcasting each tensor to the full size
506
    shapes = [tf.size(x) for x in args]
507
    sz = [x.get_shape().as_list()[0] for x in args]
508
509
    # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype
510
    if indexing == "xy" and ndim > 1:
511
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
512
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
513
        shapes[0], shapes[1] = shapes[1], shapes[0]
514
        sz[0], sz[1] = sz[1], sz[0]
515
516
    for i in range(len(output)):
517
        stack_sz = [*sz[:i], 1, *sz[(i + 1):]]
518
        if indexing == 'xy' and ndim > 1 and i < 2:
519
            stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0]
520
        output[i] = tf.tile(output[i], tf.stack(stack_sz))
521
    return output
522
523
524
def flatten(v):
525
    """flatten Tensor v"""
526
527
    return tf.reshape(v, [-1])
528
529
530
def prod_n(lst):
531
    prod = lst[0]
532
    for p in lst[1:]:
533
        prod *= p
534
    return prod
535
536
537
def sub2ind(siz, subs):
538
    """assumes column-order major"""
539
    # subs is a list
540
    assert len(siz) == len(subs), 'found inconsistent siz and subs: %d %d' % (len(siz), len(subs))
541
542
    k = np.cumprod(siz[::-1])
543
544
    ndx = subs[-1]
545
    for i, v in enumerate(subs[:-1][::-1]):
546
        ndx = ndx + v * k[i]
547
548
    return ndx