ext/neuron/layers.py

"""
tensorflow/keras utilities for the neuron project

If you use this code, please cite
Dalca AV, Guttag J, Sabuncu MR
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,
CVPR 2018

or for the transformation/integration functions:

Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
MICCAI 2018.

Contact: adalca [at] csail [dot] mit [dot] edu
License: GPLv3
"""

# third party
import tensorflow as tf
from keras import backend as K
from keras.layers import Layer
from copy import deepcopy

# local
from ext.neuron.utils import transform, resize, integrate_vec, affine_to_shift, combine_non_linear_and_aff_to_shift


class SpatialTransformer(Layer):
    """
    N-D Spatial Transformer Tensorflow / Keras Layer

    The Layer can handle both affine and dense transforms.
    Both transforms are meant to give a 'shift' from the current position.
    Therefore, a dense transform gives displacements (not absolute locations) at each voxel,
    and an affine transform gives the *difference* of the affine matrix from
    the identity matrix.

    If you find this function useful, please cite:
    Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
    Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
    MICCAI 2018.

    Originally, this code was based on voxelmorph code, which
    was in turn transformed to be dense with the help of (affine) STN code
    via https://github.com/kevinzakka/spatial-transformer-network

    Since then, we've re-written the code to be generalized to any
    dimensions, and along the way wrote grid and interpolation functions.
    """

    def __init__(self,
                 interp_method='linear',
                 indexing='ij',
                 single_transform=False,
                 **kwargs):
        """
        Parameters:
            interp_method: 'linear' or 'nearest'
            single_transform: whether a single transform is supplied for the whole batch
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
                'xy' indexing will have the first two entries of the flow
                (along the last axis) flipped compared to 'ij' indexing
        """
        self.interp_method = interp_method
        self.ndims = None
        self.inshape = None
        self.single_transform = single_transform
        self.is_affine = list()

        assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
        self.indexing = indexing

        super(SpatialTransformer, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["interp_method"] = self.interp_method
        config["indexing"] = self.indexing
        config["single_transform"] = self.single_transform
        return config

    def build(self, input_shape):
        """
        input_shape should be a list for two or three inputs:
            input1: image.
            input2 [, input3]: transform Tensors
                if affine: should be an (N+1) x (N+1) matrix
                    *or* a flattened vector of length N*(N+1) (which will be reshaped
                    to N x (N+1) and an identity row added)
                if not affine: should be a *vol_shape x N dense shift field
        """

        if len(input_shape) not in (2, 3):
            raise Exception('Spatial Transformer must be called on a list of min length 2 and max length 3.'
                            ' First argument is the image followed by the affine and non linear transforms.')

        # set up number of dimensions
        self.ndims = len(input_shape[0]) - 2
        self.inshape = input_shape
        trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]]

        for (i, shape) in enumerate(trf_shape):

            # the transform is an affine iff:
            # it's a 1D Tensor [dense transforms need to be at least ndims + 1]
            # it's a 2D Tensor and shape == [N+1, N+1]
            self.is_affine.append(len(shape) == 1 or
                                  (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape])))

            # check sizes
            if self.is_affine[i] and len(shape) == 1:
                ex = self.ndims * (self.ndims + 1)
                if shape[0] != ex:
                    raise Exception('Expected flattened affine of len %d but got %d' % (ex, shape[0]))

            if not self.is_affine[i]:
                if shape[-1] != self.ndims:
                    raise Exception('Offset flow field size expected: %d, found: %d' % (self.ndims, shape[-1]))

        # confirm built
        self.built = True

    def call(self, inputs, **kwargs):
        """
        Parameters
            inputs: list with several entries: the volume followed by the transforms
        """

        # check shapes
        assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len(inputs)
        vol = inputs[0]
        trf = inputs[1:]

        # necessary for multi_gpu models...
        vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
        for i in range(len(trf)):
            trf[i] = K.reshape(trf[i], [-1, *self.inshape[i + 1][1:]])

        # reorder transforms, non-linear first and affine second (use local copies
        # rather than mutating self.is_affine / self.inshape, so that layer state
        # stays consistent if call() is traced more than once)
        ind_nonlinear_linear = [i[0] for i in sorted(enumerate(self.is_affine), key=lambda x: x[1])]
        is_affine = [self.is_affine[i] for i in ind_nonlinear_linear]
        trf = [trf[i] for i in ind_nonlinear_linear]

        # go from affine to deformation field
        if len(trf) == 1:
            trf = trf[0]
            if is_affine[0]:
                trf = tf.map_fn(lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32)
        # combine non-linear and affine to obtain a single deformation field
        elif len(trf) == 2:
            trf = tf.map_fn(lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32)

        # prepare location shift
        if self.indexing == 'xy':  # shift the first two dimensions
            trf_split = tf.split(trf, trf.shape[-1], axis=-1)
            trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
            trf = tf.concat(trf_lst, -1)

        # map transform across batch
        if self.single_transform:
            # apply the same (first) transform to every volume in the batch
            fn = lambda x: self._single_transform([x, trf[0, :]])
            return tf.map_fn(fn, vol, dtype=tf.float32)
        else:
            return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)

    def _single_aff_to_shift(self, trf, volshape):
        if len(trf.shape) == 1:  # go from vector to matrix
            trf = tf.reshape(trf, [self.ndims, self.ndims + 1])
        return affine_to_shift(trf, volshape, shift_center=True)

    def _non_linear_and_aff_to_shift(self, trf, volshape):
        if len(trf[1].shape) == 1:  # go from vector to matrix
            trf[1] = tf.reshape(trf[1], [self.ndims, self.ndims + 1])
        return combine_non_linear_and_aff_to_shift(trf, volshape, shift_center=True)

    def _single_transform(self, inputs):
        return transform(inputs[0], inputs[1], interp_method=self.interp_method)
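

# A minimal usage sketch for SpatialTransformer: warp a 3D volume with a dense
# shift field. The input shapes below are arbitrary illustrative choices, not
# values required by the layer. An affine could be passed as an extra input,
# e.g. KL.Input(shape=(4, 4)) holding the difference from the identity matrix.
def _example_spatial_transformer():
    import keras.layers as KL
    from keras.models import Model

    vol = KL.Input(shape=(160, 160, 192, 1))   # moving image
    flow = KL.Input(shape=(160, 160, 192, 3))  # per-voxel displacements ('ij' indexing)
    moved = SpatialTransformer(interp_method='linear', indexing='ij')([vol, flow])
    return Model(inputs=[vol, flow], outputs=moved)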


class VecInt(Layer):
    """
    Vector Integration Layer

    Enables vector integration via several methods
    (ode or quadrature for time-dependent vector fields,
    scaling and squaring for stationary fields)

    If you find this function useful, please cite:
    Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
    Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
    MICCAI 2018.
    """

    def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1,
                 ode_args=None,
                 odeint_fn=None, **kwargs):
        """
        Parameters:
            method can be any of the methods in neuron.utils.integrate_vec
            indexing can be 'xy' (switches the first two dimensions) or 'ij'
            int_steps is the number of integration steps
            out_time_pt is the time point at which to output, if using odeint integration
        """

        assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
        self.indexing = indexing
        self.method = method
        self.int_steps = int_steps
        self.inshape = None
        self.out_time_pt = out_time_pt
        self.odeint_fn = odeint_fn  # if None, a default tensorflow odeint is used
        self.ode_args = ode_args
        if ode_args is None:
            self.ode_args = {'rtol': 1e-6, 'atol': 1e-12}
        super(VecInt, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["indexing"] = self.indexing
        config["method"] = self.method
        config["int_steps"] = self.int_steps
        config["out_time_pt"] = self.out_time_pt
        config["ode_args"] = self.ode_args
        config["odeint_fn"] = self.odeint_fn
        return config

    def build(self, input_shape):
        # confirm built
        self.built = True

        trf_shape = input_shape
        if isinstance(input_shape[0], (list, tuple)):
            trf_shape = input_shape[0]
        self.inshape = trf_shape

        if trf_shape[-1] != len(trf_shape) - 2:
            raise Exception('transform ndims %d does not match expected ndims %d' % (trf_shape[-1], len(trf_shape) - 2))

    def call(self, inputs, **kwargs):
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        loc_shift = inputs[0]

        # necessary for multi_gpu models...
        loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]])

        # prepare location shift
        if self.indexing == 'xy':  # shift the first two dimensions
            loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1)
            loc_shift_lst = [loc_shift_split[1], loc_shift_split[0], *loc_shift_split[2:]]
            loc_shift = tf.concat(loc_shift_lst, -1)

        if len(inputs) > 1:
            assert self.out_time_pt is None, 'out_time_pt should be None if providing a batch-based out_time_pt'

        # map transform across batch
        out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32)
        return out

    def _single_int(self, inputs):

        vel = inputs[0]
        out_time_pt = self.out_time_pt
        if len(inputs) == 2:
            out_time_pt = inputs[1]
        return integrate_vec(vel, method=self.method,
                             nb_steps=self.int_steps,
                             ode_args=self.ode_args,
                             out_time_pt=out_time_pt,
                             odeint_fn=self.odeint_fn)
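

# A minimal usage sketch for VecInt: turn a stationary velocity field into a
# diffeomorphic displacement field by scaling and squaring ('ss'). With
# int_steps=7, the field is first scaled by 1/2**7, and the resulting small
# deformation is composed with itself 7 times. Shapes below are arbitrary
# illustrative choices.
def _example_vec_int():
    import keras.layers as KL
    from keras.models import Model

    vel = KL.Input(shape=(160, 160, 192, 3))  # stationary velocity field
    disp = VecInt(indexing='ij', method='ss', int_steps=7)(vel)
    return Model(inputs=vel, outputs=disp)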


class Resize(Layer):
    """
    N-D Resize Tensorflow / Keras Layer
    Note: this is not re-shaping an existing volume, but resizing, like scipy's "zoom"

    If you find this function useful, please cite:
    Dalca AV, Guttag J, Sabuncu MR
    Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,
    CVPR 2018

    Since then, we've re-written the code to be generalized to any
    dimensions, and along the way wrote grid and interpolation functions.
    """

    def __init__(self,
                 zoom_factor=None,
                 size=None,
                 interp_method='linear',
                 **kwargs):
        """
        Parameters:
            zoom_factor: float, or list/tuple of floats (one per dimension);
                leave as None if size is given
            size: int, or list/tuple of ints (one per dimension);
                leave as None if zoom_factor is given
            interp_method: 'linear' or 'nearest'
        """
        self.zoom_factor = zoom_factor
        self.size = list(size) if isinstance(size, (list, tuple)) else size
        self.zoom_factor0 = None
        self.size0 = None
        self.interp_method = interp_method
        self.ndims = None
        self.inshape = None
        super(Resize, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["zoom_factor"] = self.zoom_factor
        config["size"] = self.size
        config["interp_method"] = self.interp_method
        return config

    def build(self, input_shape):
        """
        input_shape should be a single volume shape, or a list containing one
        volume shape of the form *vol_shape x N
        """

        if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1:
            raise Exception('Resize must be called on a list of length 1.')

        if isinstance(input_shape[0], (list, tuple)):
            input_shape = input_shape[0]

        # set up number of dimensions
        self.ndims = len(input_shape) - 2
        self.inshape = input_shape

        # check zoom_factor
        if isinstance(self.zoom_factor, float):
            self.zoom_factor0 = [self.zoom_factor] * self.ndims
        elif self.zoom_factor is None:
            self.zoom_factor0 = [0] * self.ndims
        elif isinstance(self.zoom_factor, (list, tuple)):
            self.zoom_factor0 = deepcopy(self.zoom_factor)
            assert len(self.zoom_factor0) == self.ndims, \
                'zoom factor length {} does not match number of dimensions {}'.format(len(self.zoom_factor), self.ndims)
        else:
            raise Exception('zoom_factor should be a float or a list/tuple of floats (or None if size is provided)')

        # check size
        if isinstance(self.size, int):
            self.size0 = [self.size] * self.ndims
        elif self.size is None:
            self.size0 = [0] * self.ndims
        elif isinstance(self.size, (list, tuple)):
            self.size0 = deepcopy(self.size)
            assert len(self.size0) == self.ndims, \
                'size length {} does not match number of dimensions {}'.format(len(self.size0), self.ndims)
        else:
            raise Exception('size should be an int or a list/tuple of ints (or None if zoom_factor is provided)')

        # confirm built
        self.built = True

        super(Resize, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        """
        Parameters
            inputs: volume or list of one volume
        """

        # check shapes
        if isinstance(inputs, (list, tuple)):
            assert len(inputs) == 1, "inputs has to be len 1. found: %d" % len(inputs)
            vol = inputs[0]
        else:
            vol = inputs

        # necessary for multi_gpu models...
        vol = K.reshape(vol, [-1, *self.inshape[1:]])

        # set value of missing size or zoom_factor
        if not any(self.zoom_factor0):
            self.zoom_factor0 = [self.size0[i] / self.inshape[i + 1] for i in range(self.ndims)]
        else:
            self.size0 = [int(self.inshape[f + 1] * self.zoom_factor0[f]) for f in range(self.ndims)]

        # map transform across batch
        return tf.map_fn(self._single_resize, vol, dtype=vol.dtype)

    def compute_output_shape(self, input_shape):

        output_shape = [input_shape[0]]
        output_shape += [int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims)]
        output_shape += [input_shape[-1]]
        return tuple(output_shape)

    def _single_resize(self, inputs):
        return resize(inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method)


# Zoom is an alias of Resize, to match scipy's naming
Zoom = Resize
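

# A minimal usage sketch for Resize / Zoom: upsample a volume by a factor of 2
# along every spatial dimension, in the spirit of scipy.ndimage.zoom. Exactly
# one of zoom_factor and size should be provided; shapes here are arbitrary
# illustrative choices.
def _example_resize():
    import keras.layers as KL
    from keras.models import Model

    vol = KL.Input(shape=(80, 80, 96, 1))
    up = Zoom(zoom_factor=2., interp_method='linear')(vol)  # Zoom is the Resize alias
    return Model(inputs=vol, outputs=up)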


#########################################################
# "Local" layers -- layers with parameters at each voxel
#########################################################

class LocalBias(Layer):
    """
    Local bias layer: each pixel/voxel has its own bias operation (one parameter per voxel)
    out[v] = in[v] + b[v]
    """

    def __init__(self, my_initializer='RandomNormal', biasmult=1.0, **kwargs):
        self.initializer = my_initializer
        self.biasmult = biasmult
        self.kernel = None
        super(LocalBias, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["my_initializer"] = self.initializer
        config["biasmult"] = self.biasmult
        return config

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=input_shape[1:],
                                      initializer=self.initializer,
                                      trainable=True)
        super(LocalBias, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x, **kwargs):
        return x + self.kernel * self.biasmult  # weights are the difference from the input

    def compute_output_shape(self, input_shape):
        return input_shape
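

# A minimal usage sketch for LocalBias: add a trainable per-voxel bias to a
# feature map, i.e. out[v] = in[v] + b[v] * biasmult. Shapes below are
# arbitrary illustrative choices.
def _example_local_bias():
    import keras.layers as KL
    from keras.models import Model

    feat = KL.Input(shape=(160, 160, 192, 1))
    out = LocalBias(my_initializer='RandomNormal', biasmult=1.0)(feat)
    return Model(inputs=feat, outputs=out)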