|
a |
|
b/ext/neuron/utils.py |
|
|
1 |
""" |
|
|
2 |
tensorflow/keras utilities for the neuron project |
|
|
3 |
|
|
|
4 |
If you use this code, please cite |
|
|
5 |
Dalca AV, Guttag J, Sabuncu MR |
|
|
6 |
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, |
|
|
7 |
CVPR 2018 |
|
|
8 |
|
|
|
9 |
or for the transformation/interpolation related functions: |
|
|
10 |
|
|
|
11 |
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration |
|
|
12 |
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu |
|
|
13 |
MICCAI 2018. |
|
|
14 |
|
|
|
15 |
Contact: adalca [at] csail [dot] mit [dot] edu |
|
|
16 |
License: GPLv3 |
|
|
17 |
""" |
|
|
18 |
|
|
|
19 |
import itertools |
|
|
20 |
import numpy as np |
|
|
21 |
import tensorflow as tf |
|
|
22 |
import keras.backend as K |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def interpn(vol, loc, interp_method='linear'):
    """
    N-D gridded interpolation in tensorflow

    vol may have one more (trailing feature) dimension than there are location
    components; in that case every feature channel is interpolated at each location.
    Locations outside the volume are clipped to the border voxels.

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc: an N-long list of N-D Tensors (the interpolation locations) for the new grid
            each tensor has to have the same size (but not nec. same size as vol)
            or a tensor of size [*new_vol_shape, D]
        interp_method: interpolation type 'linear' (default) or 'nearest'

    Returns:
        new interpolated volume of the same size as the entries in loc, with a
        trailing feature axis of size nb_features (1 if vol had no feature axis)

    Raises:
        ValueError: if the number of location components does not match the
            volume rank, or if interp_method is not 'linear' or 'nearest'.
    """

    if isinstance(loc, (list, tuple)):
        loc = tf.stack(loc, -1)
    nb_dims = loc.shape[-1]

    # vol is either [*vol_shape] or [*vol_shape, nb_features]
    if len(vol.shape) not in [nb_dims, nb_dims + 1]:
        raise ValueError("Number of loc Tensors %d does not match volume dimension %d"
                         % (nb_dims, len(vol.shape[:-1])))

    # explicit check instead of an assert (asserts are stripped under `python -O`)
    if interp_method not in ('linear', 'nearest'):
        raise ValueError("interp_method should be 'linear' or 'nearest', found: %s"
                         % interp_method)

    # give vol a trailing feature axis so both branches gather whole feature vectors
    if len(vol.shape) == nb_dims:
        vol = K.expand_dims(vol, -1)

    # float location Tensors
    loc = tf.cast(loc, 'float32')

    if isinstance(vol.shape, tf.TensorShape):
        volshape = vol.shape.as_list()
    else:
        volshape = vol.shape

    # [nb_voxels, nb_features] view of vol, built once and shared by every gather below
    vol_flat = tf.reshape(vol, [-1, volshape[-1]])

    if interp_method == 'linear':
        loc0 = tf.floor(loc)

        # clip values to the valid index range of each spatial dimension
        max_loc = [d - 1 for d in volshape]
        clipped_loc = [tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims)]
        loc0lst = [tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims)]

        # get other end of point cube
        loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)]
        locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]]

        # compute the difference between the upper value and the original value
        # differences are basically 1 - (pt - floor(pt))
        # because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt))
        diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)]
        diff_loc0 = [1 - d for d in diff_loc1]
        # note reverse ordering since weights are inverse of diff.
        weights_loc = [diff_loc1, diff_loc0]

        # go through all the cube corners, indexed by a ND binary vector
        # e.g. [0, 0] means this "first" corner in a 2-D "cube"
        cube_pts = list(itertools.product([0, 1], repeat=nb_dims))
        interp_vol = 0

        for c in cube_pts:
            # get nd values
            # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091
            # On GPU, tf.gather does not validate indices and fills out-of-range
            # lookups with zero, while the CPU version raises; the clipping above
            # keeps every index in range.
            subs = [locs[c[d]][d] for d in range(nb_dims)]
            idx = sub2ind(volshape[:-1], subs)
            vol_val = tf.gather(vol_flat, idx)

            # get the weight of this cube_pt based on the distance
            # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1
            # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0
            wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)]
            wt = prod_n(wts_lst)
            wt = K.expand_dims(wt, -1)

            # compute final weighted value for each cube corner
            interp_vol += wt * vol_val

    else:  # interp_method == 'nearest'
        roundloc = tf.cast(tf.round(loc), 'int32')

        # clip values to the valid index range of each spatial dimension
        max_loc = [tf.cast(d - 1, 'int32') for d in volshape]
        roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)]

        # get values
        idx = sub2ind(volshape[:-1], roundloc)
        interp_vol = tf.gather(vol_flat, idx)

    return interp_vol
|
|
125 |
|
|
|
126 |
|
|
|
127 |
def resize(vol, zoom_factor, new_shape, interp_method='linear'):
    """
    Resize (zoom) a volume onto a grid of shape new_shape by interpolation.

    If zoom_factor is a list/tuple, its length determines ndims, in which case vol
    has to have ndims or ndims + 1 dimensions (the optional last one holding features).

    If zoom_factor is a scalar, vol must have ndims + 1 dimensions (last = features)
    and the same factor is used for every spatial dimension.

    Parameters:
        vol: volume Tensor, [*vol_shape] or [*vol_shape, nb_features]
        zoom_factor: scalar or list of per-dimension factors; the output voxel at
            grid position x samples vol at x / zoom_factor
        new_shape: list of length ndims giving the output spatial shape
        interp_method: 'linear' (default) or 'nearest', passed to transform

    Returns:
        the resampled volume, as computed by transform

    Raises:
        ValueError: if zoom_factor is a list whose length is incompatible with
            the number of dimensions of vol.
    """

    if isinstance(zoom_factor, (list, tuple)):
        ndims = len(zoom_factor)
        vol_ndims = len(vol.shape)
        # the previous check sliced vol.shape[:ndims] first and so could never fail
        # for volumes with too many dimensions
        if vol_ndims not in (ndims, ndims + 1):
            raise ValueError("zoom_factor length %d does not match vol dimensions %d"
                             % (ndims, vol_ndims))
    else:
        ndims = len(vol.shape[:-1])
        zoom_factor = [zoom_factor] * ndims

    # get grid for new shape
    grid = volshape_to_ndgrid(new_shape)
    grid = [tf.cast(f, 'float32') for f in grid]

    # shift each output coordinate x to x / zoom, expressed as a displacement
    offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)]
    offset = tf.stack(offset, ndims)

    # transform
    return transform(vol, offset, interp_method)
|
|
155 |
|
|
|
156 |
|
|
|
157 |
# `zoom` is an alias for `resize` (bound to the same function object)
zoom = resize
|
|
158 |
|
|
|
159 |
|
|
|
160 |
def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'):
    """
    transform an affine matrix to a dense location shift tensor in tensorflow

    Algorithm:
        - get grid and shift grid to be centered at the center of the image (optionally)
        - apply affine matrix to each index.
        - subtract grid

    Parameters:
        affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor), or a vector of
            ND * (ND + 1) elements that is reshaped to ND x ND+1
        volshape: sequence (or tf.TensorShape) of the N volume sizes
        shift_center (optional): if True, subtract (size - 1) / 2 from each grid
            coordinate so the affine acts about the volume center
        indexing: 'ij' (default) or 'xy', passed to volshape_to_meshgrid

    Returns:
        shift field (Tensor) of size *volshape x N

    Raises:
        ValueError: if a vector affine does not have ND * (ND + 1) elements, or
            if the affine matrix does not have shape ND+1 x ND+1 or ND x ND+1.
    """

    if isinstance(volshape, tf.TensorShape):
        volshape = volshape.as_list()

    if affine_matrix.dtype != 'float32':
        affine_matrix = tf.cast(affine_matrix, 'float32')

    nb_dims = len(volshape)

    if len(affine_matrix.shape) == 1:
        # use the static shape: len() is not defined for symbolic (graph-mode) Tensors
        vec_shape = affine_matrix.shape
        vec_shape = vec_shape.as_list() if isinstance(vec_shape, tf.TensorShape) else list(vec_shape)
        nb_elems = vec_shape[0]
        # an unknown static length is left for tf.reshape to validate at run time
        if nb_elems is not None and nb_elems != nb_dims * (nb_dims + 1):
            raise ValueError('transform is supposed a vector of len ndims * (ndims + 1). '
                             'Got len %s' % nb_elems)

        affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1])

    mat_shape = affine_matrix.shape
    mat_shape = mat_shape.as_list() if isinstance(mat_shape, tf.TensorShape) else list(mat_shape)
    if not (mat_shape[0] in [nb_dims, nb_dims + 1] and mat_shape[1] == (nb_dims + 1)):
        raise ValueError('Affine matrix shape should match '
                         '%d+1 x %d+1 or ' % (nb_dims, nb_dims) +
                         '%d x %d+1. ' % (nb_dims, nb_dims) +
                         'Got: ' + str(mat_shape))

    # list of volume ndgrid
    # N-long list, each entry of shape volshape
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)
    mesh = [tf.cast(f, 'float32') for f in mesh]

    if shift_center:
        mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))]

    # add an all-ones entry (homogeneous coordinate) and transform into a large matrix
    flat_mesh = [flatten(f) for f in mesh]
    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # N+1 x nb_voxels

    # compute locations
    loc_matrix = tf.matmul(affine_matrix, mesh_matrix)  # (N or N+1) x nb_voxels
    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N

    # get shifts and return
    return loc - tf.stack(mesh, axis=nb_dims)
|
|
220 |
|
|
|
221 |
|
|
|
222 |
def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=True, indexing='ij'):
    """
    Compose a dense non-linear shift field with an affine matrix into a single
    dense location shift tensor in tensorflow

    Algorithm:
        - get grid and shift grid to be centered at the center of the image (optionally)
        - add the non-linear shift to the grid
        - apply the affine matrix to each shifted index
        - subtract the (un-shifted) grid

    Parameters:
        transform_list: [non_linear, affine], where non_linear is a shift field whose
            last axis holds the N components (each broadcastable to volshape), and
            affine is an ND+1 x ND+1 or ND x ND+1 tensor, or a vector of
            ND * (ND + 1) elements that is reshaped to ND x ND+1.
            The list itself is not modified.
        volshape: sequence (or tf.TensorShape) of the N volume sizes
        shift_center (optional): if True, subtract (size - 1) / 2 from each grid
            coordinate so the affine acts about the volume center
        indexing: 'ij' (default) or 'xy', passed to volshape_to_meshgrid

    Returns:
        shift field (Tensor) of size *volshape x N

    Raises:
        ValueError: if a vector affine does not have ND * (ND + 1) elements, or
            if the affine matrix does not have shape ND+1 x ND+1 or ND x ND+1.
    """

    if isinstance(volshape, tf.TensorShape):
        volshape = volshape.as_list()

    # work on local references so the caller's list is left untouched
    non_linear = transform_list[0]
    affine = transform_list[1]

    # convert transforms to floats
    if non_linear.dtype != 'float32':
        non_linear = tf.cast(non_linear, 'float32')
    if affine.dtype != 'float32':
        affine = tf.cast(affine, 'float32')

    nb_dims = len(volshape)

    # transform affine to matrix if given as vector
    if len(affine.shape) == 1:
        # use the static shape: len() is not defined for symbolic (graph-mode) Tensors
        vec_shape = affine.shape
        vec_shape = vec_shape.as_list() if isinstance(vec_shape, tf.TensorShape) else list(vec_shape)
        nb_elems = vec_shape[0]
        # an unknown static length is left for tf.reshape to validate at run time
        if nb_elems is not None and nb_elems != nb_dims * (nb_dims + 1):
            raise ValueError('transform is supposed a vector of len ndims * (ndims + 1). '
                             'Got len %s' % nb_elems)

        affine = tf.reshape(affine, [nb_dims, nb_dims + 1])

    mat_shape = affine.shape
    mat_shape = mat_shape.as_list() if isinstance(mat_shape, tf.TensorShape) else list(mat_shape)
    if not (mat_shape[0] in [nb_dims, nb_dims + 1] and mat_shape[1] == (nb_dims + 1)):
        raise ValueError('Affine matrix shape should match '
                         '%d+1 x %d+1 or ' % (nb_dims, nb_dims) +
                         '%d x %d+1. ' % (nb_dims, nb_dims) +
                         'Got: ' + str(mat_shape))

    # list of volume ndgrid
    # N-long list, each entry of shape volshape
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)
    mesh = [tf.cast(f, 'float32') for f in mesh]

    if shift_center:
        mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))]

    # apply the non-linear shift, then add an all-ones (homogeneous) entry and
    # transform into a large matrix
    non_linear_mesh = tf.unstack(non_linear, axis=-1)
    flat_mesh = [flatten(mesh[i] + non_linear_mesh[i]) for i in range(len(mesh))]
    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # N+1 x nb_voxels

    # compute locations
    loc_matrix = tf.matmul(affine, mesh_matrix)  # (N or N+1) x nb_voxels
    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N

    # get shifts (relative to the un-shifted grid) and return
    return loc - tf.stack(mesh, axis=nb_dims)
|
|
287 |
|
|
|
288 |
|
|
|
289 |
def transform(vol, loc_shift, interp_method='linear', indexing='ij'):
    """
    Warp an N-D volume (with optional features) by a dense shift field in tensorflow.

    The output at grid location x holds vol interpolated at x + loc_shift[x],
    i.e. data is pulled from the shifted position.

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc_shift: shift volume [*new_vol_shape, N]
        interp_method (default:'linear'): 'linear', 'nearest'
        indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
            In general, prefer to leave this 'ij'

    Return:
        the volume interpolated by interpn on the grid new_vol_shape
        (loc_shift's shape without its last axis)
    """

    # spatial shape of the output = shift field shape minus the component axis
    shift_shape = loc_shift.shape
    if isinstance(shift_shape, tf.TensorShape):
        volshape = shift_shape[:-1].as_list()
    else:
        volshape = shift_shape[:-1]
    nb_dims = len(volshape)

    # absolute sample positions: identity grid plus the per-axis shift
    grid = volshape_to_meshgrid(volshape, indexing=indexing)
    sample_locs = []
    for axis in range(nb_dims):
        sample_locs.append(tf.cast(grid[axis], 'float32') + loc_shift[..., axis])

    return interpn(vol, sample_locs, interp_method=interp_method)
|
|
321 |
|
|
|
322 |
|
|
|
323 |
def integrate_vec(vec, time_dep=False, method='ss', **kwargs):
    """
    Integrate (stationary or time-dependent) vector field (N-D Tensor) in tensorflow

    Implements "scaling and squaring" and quadrature. The method name 'ode' is
    recognised, but no odeint-based integrator is available in this module.

    Parameters:
        vec: the Tensor field to integrate.
            If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size),
            then vector shape (vec_shape) should be
            [vol_size, vol_ndim] (if stationary)
            [vol_size, vol_ndim, nb_time_steps] (if time dependent)
        time_dep: bool whether vector is time dependent
        method: 'scaling_and_squaring' or 'ss' or 'quadrature'
        nb_steps (keyword, required):
            'ss': int >= 0. The field is scaled by 1 / 2**nb_steps and composed
                nb_steps times, so nb_steps of 0 means integral = vec. Only
                integrates to time point 1. When time_dep, vec must have exactly
                2**nb_steps time steps.
            'quadrature': int >= 1, number of steps (when time_dep, the first
                nb_steps time steps are used).

    Returns:
        disp: integrated displacement field of shape [vol_size, vol_ndim]
            (the same shape as vec for a stationary field; the time axis is
            consumed for a time-dependent one)

    Raises:
        ValueError: if method is not one of the names above.
        NotImplementedError: if method is 'ode'.
    """

    if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']:
        raise ValueError("method has to be 'scaling_and_squaring', 'ss' or 'quadrature'. "
                         "found: %s" % method)

    if method == 'ode':
        # no odeint integrator is implemented here; do not silently fall back to quadrature
        raise NotImplementedError("method 'ode' is not implemented; "
                                  "use 'scaling_and_squaring' / 'ss' or 'quadrature'")

    if method in ['ss', 'scaling_and_squaring']:
        nb_steps = kwargs['nb_steps']
        assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps

        if time_dep:
            # move the trailing time axis to the front: [nb_time_steps, *vol_size, vol_ndim]
            # (the permutation length must be the tensor rank, not the number of time steps)
            ndim = len(vec.shape)
            svec = K.permute_dimensions(vec, [ndim - 1, *range(ndim - 1)])
            assert 2 ** nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match"

            svec = svec / (2 ** nb_steps)
            for _ in range(nb_steps):
                # halve the number of time steps by composing consecutive pairs:
                # even + transform(odd, even). map_fn takes (fn, elems, dtype), so the
                # pair is passed as a single tuple of elems.
                composed = tf.map_fn(lambda pair: transform(pair[0], pair[1]),
                                     (svec[1::2], svec[0::2]),
                                     dtype=svec.dtype)
                svec = svec[0::2] + composed

            disp = svec[0, :]

        else:
            vec = vec / (2 ** nb_steps)
            for _ in range(nb_steps):
                vec += transform(vec, vec)
            disp = vec

    else:  # method == 'quadrature'
        nb_steps = kwargs['nb_steps']
        assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps

        vec = vec / nb_steps

        if time_dep:
            disp = vec[..., 0]
            for si in range(nb_steps - 1):
                disp += transform(vec[..., si + 1], disp)
        else:
            disp = vec
            for _ in range(nb_steps - 1):
                disp += transform(vec, disp)

    return disp
|
|
387 |
|
|
|
388 |
|
|
|
389 |
def volshape_to_ndgrid(volshape, **kwargs):
    """
    Build an 'ij'-indexed coordinate grid of Tensors covering a volume.

    Parameters:
        volshape: the volume size, a sequence of integer-valued entries
        **kwargs: forwarded to ndgrid

    Returns:
        A list of Tensors, one per dimension, as produced by ndgrid on
        tf.range(0, size) for each size in volshape

    Raises:
        ValueError: if any entry of volshape is not integer-valued

    See Also:
        ndgrid
    """

    integer_flags = [float(dim_size).is_integer() for dim_size in volshape]
    if False in integer_flags:
        raise ValueError("volshape needs to be a list of integers")

    ranges = [tf.range(0, dim_size) for dim_size in volshape]
    return ndgrid(*ranges, **kwargs)
|
|
409 |
|
|
|
410 |
|
|
|
411 |
def volshape_to_meshgrid(volshape, **kwargs):
    """
    Build a coordinate meshgrid of Tensors covering a volume.

    Parameters:
        volshape: the volume size, a sequence of integer-valued entries
        **kwargs: forwarded to meshgrid (e.g. indexing='ij'; meshgrid's
            own default indexing applies when it is omitted)

    Returns:
        A list of Tensors, as produced by meshgrid on tf.range(0, size)
        for each size in volshape

    Raises:
        ValueError: if any entry of volshape is not integer-valued

    See Also:
        tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid
    """

    checks = []
    for dim_size in volshape:
        checks.append(float(dim_size).is_integer())
    if not all(checks):
        raise ValueError("volshape needs to be a list of integers")

    axis_ranges = list(map(lambda size: tf.range(0, size), volshape))
    return meshgrid(*axis_ranges, **kwargs)
|
|
431 |
|
|
|
432 |
|
|
|
433 |
def ndgrid(*args, **kwargs):
    """
    Broadcast rank-1 Tensors onto an N-D grid using matrix ('ij') indexing.

    Thin wrapper around meshgrid with indexing fixed to 'ij', so output i
    varies along axis i.

    Parameters:
        *args: Tensors with rank 1
        **kwargs: forwarded unchanged to meshgrid; must not contain 'indexing'
            (it is already supplied)

    Returns:
        A list of Tensors, one per input
    """
    grids = meshgrid(*args, indexing='ij', **kwargs)
    return grids
|
|
447 |
|
|
|
448 |
|
|
|
449 |
def meshgrid(*args, **kwargs):
    """
    meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically
    improves runtime by changing the last step to tiling instead of multiplication.
    https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921

    Broadcasts parameters for evaluation on an N-D grid.
    Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
    of N-D coordinate arrays for evaluating expressions on an N-D grid.

    Notes:
        `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
        When the `indexing` argument is set to 'xy' (the default), the broadcasting
        instructions for the first two dimensions are swapped.

    Examples:
        Calling `X, Y = meshgrid(x, y)` with the tensors
        ```python
        x = [1, 2, 3]
        y = [4, 5, 6]
        X, Y = meshgrid(x, y)
        # X = [[1, 2, 3],
        #      [1, 2, 3],
        #      [1, 2, 3]]
        # Y = [[4, 4, 4],
        #      [5, 5, 5],
        #      [6, 6, 6]]
        ```
        With 'xy' indexing and len(x) == a, len(y) == b, both outputs have shape (b, a).

    Args:
        *args: `Tensor`s with rank 1 and a statically known length.
        **kwargs:
            - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
            - name: accepted for compatibility with tf.meshgrid; ignored.

    Returns:
        outputs: A list of N `Tensor`s with rank N.

    Raises:
        TypeError: When an unsupported keyword argument is passed.
        ValueError: When indexing keyword argument is not one of `xy` or `ij`.
    """

    indexing = kwargs.pop("indexing", "xy")
    # accepted for signature parity with tf.meshgrid; no name scope is opened here
    kwargs.pop("name", None)
    if kwargs:
        key = list(kwargs.keys())[0]
        raise TypeError("'{}' is an invalid keyword argument "
                        "for this function".format(key))

    if indexing not in ("xy", "ij"):
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")

    ndim = len(args)
    s0 = (1,) * ndim

    # Prepare reshape by inserting dimensions with size 1 where needed:
    # input i is laid out along axis i
    output = []
    for i, x in enumerate(args):
        output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))

    # static input lengths == the full output shape (for 'ij')
    sz = [x.get_shape().as_list()[0] for x in args]

    # axis along which each input's values lie in the output
    axes = list(range(ndim))
    if indexing == "xy" and ndim > 1:
        # 'xy': the first input varies along axis 1 and the second along axis 0
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
        sz[0], sz[1] = sz[1], sz[0]
        axes[0], axes[1] = 1, 0

    # tile each input up to the full output shape `sz`; its own axis is already full
    # (the previous version swapped the multiples a second time for outputs 0/1,
    # producing wrong shapes for 'xy' with inputs of different lengths)
    for i, axis in enumerate(axes):
        multiples = list(sz)
        multiples[axis] = 1
        output[i] = tf.tile(output[i], tf.stack(multiples))
    return output
|
|
522 |
|
|
|
523 |
|
|
|
524 |
def flatten(v):
    """
    Flatten Tensor v into a rank-1 Tensor.

    Parameters:
        v: Tensor (or array-like accepted by tf.reshape)

    Returns:
        all elements of v as a 1-D Tensor, in tf.reshape order
    """
    flat_shape = [-1]
    return tf.reshape(v, flat_shape)
|
|
528 |
|
|
|
529 |
|
|
|
530 |
def prod_n(lst):
    """
    Element-wise product of all items in lst, multiplied left to right.

    Each step builds a new object with `*` (no in-place `*=`), so a mutable
    first item such as a NumPy array is never modified.

    Parameters:
        lst: sequence of numbers / arrays / Tensors to multiply (broadcasting
            follows the usual `*` rules of the item types)

    Returns:
        the product of all items; 1 for an empty sequence
    """
    if len(lst) == 0:
        return 1

    prod = lst[0]
    for p in lst[1:]:
        # out-of-place: `prod *= p` would mutate lst[0] when it is a NumPy array
        prod = prod * p
    return prod
|
|
535 |
|
|
|
536 |
|
|
|
537 |
def sub2ind(siz, subs):
    """
    Convert N-D subscripts into linear indices using row-major (C) ordering.

    The last subscript varies fastest, which matches the element order used when
    a volume is flattened with tf.reshape / np.reshape; the result agrees with
    np.ravel_multi_index(subs, siz).

    Parameters:
        siz: sequence of the N dimension sizes
        subs: list of N subscripts (ints, arrays, or Tensors), one per dimension

    Returns:
        the linear index (broadcast over the subscripts)

    Raises:
        ValueError: if siz and subs have different lengths
    """
    if len(siz) != len(subs):
        raise ValueError('found inconsistent siz and subs: %d %d' % (len(siz), len(subs)))

    # strides[i] is the product of the last i + 1 sizes
    strides = np.cumprod(siz[::-1])

    ndx = subs[-1]
    for stride, sub in zip(strides, reversed(subs[:-1])):
        ndx = ndx + sub * stride

    return ndx