Switch to unified view

a b/ants/registration/landmark_transforms.py
1
__all__ = ["fit_transform_to_paired_points", 
2
           "fit_time_varying_transform_to_point_sets"]
3
4
import numpy as np
5
import math
6
import time
7
8
import ants
9
10
def convergence_monitoring(values, window_size=10):
11
     if len(values) >= window_size:
12
         u = np.linspace(0.0, 1.0, num=window_size)
13
         scattered_data = np.expand_dims(values[-window_size:], axis=-1)
14
         parametric_data = np.expand_dims(u, axis=-1)
15
         spacing = 1 / (window_size-1)
16
         bspline_line = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_data,
17
             parametric_domain_origin=[0.0], parametric_domain_spacing=[spacing],
18
             parametric_domain_size=[window_size], number_of_fitting_levels=1, mesh_size=1,
19
             spline_order=1)
20
         bspline_slope = -(bspline_line[1][0] - bspline_line[0][0]) / spacing
21
         return(bspline_slope)
22
     else:
23
         return None
24
25
26
def fit_transform_to_paired_points(moving_points,
27
                                   fixed_points,
28
                                   transform_type="affine",
29
                                   regularization=1e-6,
30
                                   domain_image=None,
31
                                   number_of_fitting_levels=4,
32
                                   mesh_size=1,
33
                                   spline_order=3,
34
                                   enforce_stationary_boundary=True,
35
                                   displacement_weights=None,
36
                                   number_of_compositions=10,
37
                                   composition_step_size=0.5,
38
                                   sigma=0.0,
39
                                   convergence_threshold=1e-6,
40
                                   number_of_time_steps=2,
41
                                   number_of_integration_steps=100,
42
                                   rasterize_points=False,
43
                                   verbose=False
44
                                  ):
45
    """
46
    Estimate a transform from corresponding fixed and moving landmarks.
47
48
    ANTsR function: fitTransformToPairedPoints
49
50
    Arguments
51
    ---------
52
    moving_points : array
53
        Moving points specified in physical space as a n x d matrix where n is the number
54
        of points and d is the dimensionality.
55
56
    fixed_points : array
57
        Fixed points specified in physical space as a n x d matrix where n is the number
58
        of points and d is the dimensionality.
59
60
    transform_type : character
61
        'rigid', 'similarity', "affine', 'bspline', 'tps', 'diffeo', 'syn', or 'time-varying (tv)'.
62
63
    regularization : scalar
64
        Ridge penalty in [0,1] for linear transforms.
65
66
    domain_image : ANTs image
67
        Defines physical domain of the nonlinear transform.  Must be defined for nonlinear
68
        transforms.
69
70
    number_of_fitting_levels : integer
71
        Integer specifying the number of fitting levels for the B-spline interpolation of the
72
        displacement field.
73
74
    mesh_size : integer or array
75
        Defines the mesh size at the initial fitting level for the B-spline interpolation of the
76
        displacement field.
77
78
    spline_order : integer
79
        Spline order of the B-spline displacement field.
80
81
    enforce_stationary_boundary : boolean
82
        Ensure no displacements on the image boundary (B-spline only).
83
84
    displacement_weights : array
85
        Defines the individual weighting of the corresponding scattered data value.  Default = NULL
86
        meaning all displacements are weighted the same.
87
88
    number_of_compositions : integer
89
        Total number of compositions for the diffeomorphic transforms.
90
91
    composition_step_size : scalar
92
        Scalar multiplication factor of the weighting of the update field for the diffeomorphic transforms.
93
94
    sigma : scalar
95
        Gaussian smoothing standard deviation of the update field (in mm).
96
97
    convergence_threshold : scalar
98
        Composition-based convergence parameter for the diff. transforms using a
99
        window size of 10 values.
100
101
    number_of_time_steps : integer
102
        Time-varying velocity field parameter.
103
104
    number_of_integration_steps : scalar
105
        Number of steps used for integrating the velocity field.
106
107
    rasterize_points : boolean
108
       Use nearest neighbor rasterization of points for estimating the update
109
       field (potential speed-up).  Default = False.
110
111
    verbose : bool
112
        Print progress to the screen.
113
114
    Returns
115
    -------
116
117
    ANTs transform
118
119
    Example
120
    -------
121
    >>> import ants
122
    >>> import numpy as np
123
    >>> fixed = np.array([[50.0,50.0],[200.0,50.0],[200.0,200.0]])
124
    >>> moving = np.array([[50.0,50.0],[50.0,200.0],[200.0,200.0]])
125
    >>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="affine")
126
    >>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="rigid")
127
    >>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="similarity")
128
    >>> domain_image = ants.image_read(ants.get_ants_data("r16"))
129
    >>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="bspline", domain_image=domain_image, number_of_fitting_levels=5)
130
    >>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="diffeo", domain_image=domain_image, number_of_fitting_levels=6)
131
    """
132
133
    def polar_decomposition(X):
134
         U, d, V = np.linalg.svd(X, full_matrices=False)
135
         P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
136
         Z = np.matmul(U, V)
137
         if np.linalg.det(Z) < 0:
138
             n = X.shape[0]
139
             reflection_matrix = np.identity(n)
140
             reflection_matrix[0,0] = -1.0
141
             Z = np.matmul(Z, reflection_matrix)
142
         return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
143
144
    def create_zero_displacement_field(domain_image):
145
         field_array = np.zeros((*domain_image.shape, domain_image.dimension))
146
         field = ants.from_numpy(field_array, origin=domain_image.origin,
147
                 spacing=domain_image.spacing, direction=domain_image.direction,
148
                 has_components=True)
149
         return(field)
150
151
    def create_zero_velocity_field(domain_image, number_of_time_points=2):
152
         field_array = np.zeros((*domain_image.shape, number_of_time_points, domain_image.dimension))
153
         origin = (*domain_image.origin, 0.0)
154
         spacing = (*domain_image.spacing, 1.0)
155
         direction = np.eye(domain_image.dimension + 1)
156
         direction[0:domain_image.dimension,0:domain_image.dimension] = domain_image.direction
157
         field = ants.from_numpy(field_array, origin=origin, spacing=spacing, direction=direction,
158
                 has_components=True)
159
         return(field)
160
161
    allowed_transforms = ['rigid', 'affine', 'similarity', 'bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']
162
    if not transform_type.lower() in allowed_transforms:
163
        raise ValueError(transform_type + " transform not supported.")
164
165
    transform_type = transform_type.lower()
166
167
    if domain_image is None and transform_type in ['bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']:
168
        raise ValueError("Domain image needs to be specified.")
169
170
    if not fixed_points.shape == moving_points.shape:
171
        raise ValueError("Mismatch in the size of the point sets.")
172
173
    if regularization > 1:
174
        regularization = 1
175
    elif regularization < 0:
176
        regularization = 0
177
178
    number_of_points = fixed_points.shape[0]
179
    dimensionality = fixed_points.shape[1]
180
181
    if transform_type in ['rigid', 'affine', 'similarity']:
182
        center_fixed = fixed_points.mean(axis=0)
183
        center_moving = moving_points.mean(axis=0)
184
185
        x = fixed_points - center_fixed
186
        y = moving_points - center_moving
187
188
        y_prior = np.concatenate((y, np.ones((number_of_points, 1))), axis=1)
189
190
        x11 = np.concatenate((x, np.ones((number_of_points, 1))), axis=1)
191
        M = x11 * (1.0 - regularization) + regularization * y_prior
192
        Minv = np.linalg.lstsq(M, y, rcond=None)[0]
193
194
        p = polar_decomposition(Minv[0:dimensionality, 0:dimensionality].T)
195
        A = p['Xtilde']
196
        translation = Minv[dimensionality,:] + center_moving - center_fixed
197
198
        if transform_type in ['rigid', 'similarity']:
199
            # Kabsch algorithm
200
            #    http://web.stanford.edu/class/cs273/refs/umeyama.pdf
201
202
            C = np.dot(y.T, x)
203
            x_svd = np.linalg.svd(C * (1.0 - regularization) + np.eye(dimensionality) * regularization)
204
            x_det = np.linalg.det(np.dot(x_svd[0], x_svd[2]))
205
206
            if x_det < 0:
207
                x_svd[2][dimensionality-1, :] *= -1
208
209
            A = np.dot(x_svd[0], x_svd[2])
210
211
            if transform_type == 'similarity':
212
                scaling = (math.sqrt((np.power(y, 2).sum(axis=1) / number_of_points).mean()) /
213
                           math.sqrt((np.power(x, 2).sum(axis=1) / number_of_points).mean()))
214
                A = np.dot(A, np.eye(dimensionality) * scaling)
215
216
        xfrm = ants.create_ants_transform(matrix=A, translation=translation,
217
              dimension=dimensionality, center=center_fixed)
218
219
        return xfrm
220
221
    elif transform_type == "bspline":
222
223
        bspline_displacement_field = ants.fit_bspline_displacement_field(
224
            displacement_origins=fixed_points,
225
            displacements=moving_points - fixed_points,
226
            displacement_weights=displacement_weights,
227
            origin=domain_image.origin,
228
            spacing=domain_image.spacing,
229
            size=domain_image.shape,
230
            direction=domain_image.direction,
231
            number_of_fitting_levels=number_of_fitting_levels,
232
            mesh_size=mesh_size,
233
            spline_order=spline_order,
234
            enforce_stationary_boundary=enforce_stationary_boundary,
235
            rasterize_points=rasterize_points)
236
237
        xfrm = ants.transform_from_displacement_field(bspline_displacement_field)
238
239
        return xfrm
240
241
    elif transform_type == "tps":
242
243
        tps_displacement_field = ants.fit_thin_plate_spline_displacement_field(
244
            displacement_origins=fixed_points,
245
            displacements=moving_points - fixed_points,
246
            origin=domain_image.origin,
247
            spacing=domain_image.spacing,
248
            size=domain_image.shape,
249
            direction=domain_image.direction)
250
251
        xfrm = ants.transform_from_displacement_field(tps_displacement_field)
252
253
        return xfrm
254
255
    elif transform_type == "diffeo":
256
257
        if verbose:
258
            start_total_time = time.time()
259
260
        updated_fixed_points = np.empty_like(fixed_points)
261
        updated_fixed_points[:] = fixed_points
262
263
        total_field = create_zero_displacement_field(domain_image)
264
        total_field_xfrm = None
265
266
        error_values = []
267
        for i in range(number_of_compositions):
268
269
            if verbose:
270
                start_time = time.time()
271
272
            update_field = ants.fit_bspline_displacement_field(
273
              displacement_origins=updated_fixed_points,
274
              displacements=moving_points - updated_fixed_points,
275
              displacement_weights=displacement_weights,
276
              origin=domain_image.origin,
277
              spacing=domain_image.spacing,
278
              size=domain_image.shape,
279
              direction=domain_image.direction,
280
              number_of_fitting_levels=number_of_fitting_levels,
281
              mesh_size=mesh_size,
282
              spline_order=spline_order,
283
              enforce_stationary_boundary=True,
284
              rasterize_points=rasterize_points
285
            )
286
287
            update_field = update_field * composition_step_size
288
            if sigma > 0:
289
                update_field = ants.smooth_image(update_field, sigma)
290
291
            total_field = ants.compose_displacement_fields(update_field, total_field)
292
            total_field_xfrm = ants.transform_from_displacement_field(total_field)
293
294
            if i < number_of_compositions - 1:
295
                for j in range(updated_fixed_points.shape[0]):
296
                    updated_fixed_points[j,:] = total_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))
297
298
            error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - moving_points), axis=1, keepdims=True))))
299
            convergence_value = convergence_monitoring(error_values)
300
            if verbose:
301
                end_time = time.time()
302
                diff_time = end_time - start_time
303
                print("Composition " + str(i) + ": error = " + str(error_values[-1]) +
304
                      " (convergence = " + str(convergence_value) + ", elapsed time = " + str(diff_time) + ")")
305
            if not convergence_value is None and convergence_value <= convergence_threshold:
306
                break
307
308
        if verbose:
309
            end_total_time = time.time()
310
            diff_total_time = end_total_time - start_total_time
311
            print("Total elapsed time = " + str(diff_total_time) + ".")
312
313
        return(total_field_xfrm)
314
315
    elif transform_type == "syn":
316
317
        if verbose:
318
            start_total_time = time.time()
319
320
        updated_fixed_points = np.empty_like(fixed_points)
321
        updated_fixed_points[:] = fixed_points
322
        updated_moving_points = np.empty_like(moving_points)
323
        updated_moving_points[:] = moving_points
324
325
        total_field_fixed_to_middle = create_zero_displacement_field(domain_image)
326
        total_inverse_field_fixed_to_middle = create_zero_displacement_field(domain_image)
327
328
        total_field_moving_to_middle = create_zero_displacement_field(domain_image)
329
        total_inverse_field_moving_to_middle = create_zero_displacement_field(domain_image)
330
331
        error_values = []
332
        for i in range(number_of_compositions):
333
334
            if verbose:
335
                start_time = time.time()
336
337
            update_field_fixed_to_middle = ants.fit_bspline_displacement_field(
338
              displacement_origins=updated_fixed_points,
339
              displacements=updated_moving_points - updated_fixed_points,
340
              displacement_weights=displacement_weights,
341
              origin=domain_image.origin,
342
              spacing=domain_image.spacing,
343
              size=domain_image.shape,
344
              direction=domain_image.direction,
345
              number_of_fitting_levels=number_of_fitting_levels,
346
              mesh_size=mesh_size,
347
              spline_order=spline_order,
348
              enforce_stationary_boundary=True,
349
              rasterize_points=rasterize_points
350
            )
351
352
            update_field_moving_to_middle = ants.fit_bspline_displacement_field(
353
              displacement_origins=updated_moving_points,
354
              displacements=updated_fixed_points - updated_moving_points,
355
              displacement_weights=displacement_weights,
356
              origin=domain_image.origin,
357
              spacing=domain_image.spacing,
358
              size=domain_image.shape,
359
              direction=domain_image.direction,
360
              number_of_fitting_levels=number_of_fitting_levels,
361
              mesh_size=mesh_size,
362
              spline_order=spline_order,
363
              enforce_stationary_boundary=True,
364
              rasterize_points=rasterize_points
365
            )
366
367
            update_field_fixed_to_middle = update_field_fixed_to_middle * composition_step_size
368
            update_field_moving_to_middle = update_field_moving_to_middle * composition_step_size
369
            if sigma > 0:
370
                update_field_fixed_to_middle = ants.smooth_image(update_field_fixed_to_middle, sigma)
371
                update_field_moving_to_middle = ants.smooth_image(update_field_moving_to_middle, sigma)
372
373
            # Add the update field to both forward displacement fields.
374
375
            total_field_fixed_to_middle = ants.compose_displacement_fields(update_field_fixed_to_middle, total_field_fixed_to_middle)
376
            total_field_moving_to_middle = ants.compose_displacement_fields(update_field_moving_to_middle, total_field_moving_to_middle)
377
378
            # Iteratively estimate the inverse fields.
379
380
            total_inverse_field_fixed_to_middle = ants.invert_displacement_field(total_field_fixed_to_middle, total_inverse_field_fixed_to_middle)
381
            total_inverse_field_moving_to_middle = ants.invert_displacement_field(total_field_moving_to_middle, total_inverse_field_moving_to_middle)
382
383
            total_field_fixed_to_middle = ants.invert_displacement_field(total_inverse_field_fixed_to_middle, total_field_fixed_to_middle)
384
            total_field_moving_to_middle = ants.invert_displacement_field(total_inverse_field_moving_to_middle, total_field_moving_to_middle)
385
386
            total_field_fixed_to_middle_xfrm = ants.transform_from_displacement_field(total_field_fixed_to_middle)
387
            total_field_moving_to_middle_xfrm = ants.transform_from_displacement_field(total_field_moving_to_middle)
388
389
            total_inverse_field_fixed_to_middle_xfrm = ants.transform_from_displacement_field(total_inverse_field_fixed_to_middle)
390
            total_inverse_field_moving_to_middle_xfrm = ants.transform_from_displacement_field(total_inverse_field_moving_to_middle)
391
392
            if i < number_of_compositions - 1:
393
                for j in range(updated_fixed_points.shape[0]):
394
                    updated_fixed_points[j,:] = total_field_fixed_to_middle_xfrm.apply_to_point(tuple(fixed_points[j,:]))
395
                    updated_moving_points[j,:] = total_field_moving_to_middle_xfrm.apply_to_point(tuple(moving_points[j,:]))
396
397
            error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - updated_moving_points), axis=1, keepdims=True))))
398
            convergence_value = convergence_monitoring(error_values)
399
            if verbose:
400
                end_time = time.time()
401
                diff_time = end_time - start_time
402
                print("Composition " + str(i) + ": error = " + str(error_values[-1]) +
403
                      " (convergence = " + str(convergence_value) + ", elapsed time = " + str(diff_time) + ")")
404
            if not convergence_value is None and convergence_value <= convergence_threshold:
405
                break
406
407
        total_forward_field = ants.compose_displacement_fields(total_inverse_field_moving_to_middle, total_field_fixed_to_middle)
408
        total_forward_xfrm = ants.transform_from_displacement_field(total_forward_field)
409
        total_inverse_field = ants.compose_displacement_fields(total_inverse_field_fixed_to_middle, total_field_moving_to_middle)
410
        total_inverse_xfrm = ants.transform_from_displacement_field(total_inverse_field)
411
412
        if verbose:
413
            end_total_time = time.time()
414
            diff_total_time = end_total_time - start_total_time
415
            print("Total elapsed time = " + str(diff_total_time) + ".")
416
417
        return_dict = {'forward_transform' : total_forward_xfrm,
418
                       'inverse_transform' : total_inverse_xfrm,
419
                       'fixed_to_middle_transform' : total_field_fixed_to_middle_xfrm,
420
                       'middle_to_fixed_transform' : total_inverse_field_fixed_to_middle_xfrm,
421
                       'moving_to_middle_transform' : total_field_moving_to_middle_xfrm,
422
                       'middle_to_moving_transform' : total_inverse_field_moving_to_middle_xfrm
423
                       }
424
        return(return_dict)
425
426
    elif transform_type == "tv" or transform_type == "time-varying":
427
428
        if verbose:
429
            start_total_time = time.time()
430
431
        updated_fixed_points = np.empty_like(fixed_points)
432
        updated_fixed_points[:] = fixed_points
433
        updated_moving_points = np.empty_like(moving_points)
434
        updated_moving_points[:] = moving_points
435
436
        velocity_field = create_zero_velocity_field(domain_image, number_of_time_steps)
437
        velocity_field_array = velocity_field.numpy()
438
439
        last_update_derivative_field = create_zero_velocity_field(domain_image, number_of_time_steps)
440
        last_update_derivative_field_array = last_update_derivative_field.numpy()
441
442
        error_values = []
443
        for i in range(number_of_compositions):
444
445
            if verbose:
446
                start_time = time.time()
447
448
            update_derivative_field = create_zero_velocity_field(domain_image, number_of_time_steps)
449
            update_derivative_field_array = update_derivative_field.numpy()
450
451
            average_error = 0.0
452
            for n in range(number_of_time_steps):
453
454
                t = n / (number_of_time_steps - 1.0)
455
456
                if n > 0:
457
                    integrated_forward_field = ants.integrate_velocity_field(velocity_field, 0.0, t, number_of_integration_steps)
458
                    integrated_forward_field_xfrm = ants.transform_from_displacement_field(integrated_forward_field)
459
                    for j in range(updated_fixed_points.shape[0]):
460
                        updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))
461
                else:
462
                    updated_fixed_points[:] = fixed_points
463
464
                if n < number_of_time_steps - 1:
465
                    integrated_inverse_field = ants.integrate_velocity_field(velocity_field, 1.0, t, number_of_integration_steps)
466
                    integrated_inverse_field_xfrm = ants.transform_from_displacement_field(integrated_inverse_field)
467
                    for j in range(updated_moving_points.shape[0]):
468
                        updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(moving_points[j,:]))
469
                else:
470
                    updated_moving_points[:] = moving_points
471
472
                update_derivative_field_at_timepoint = ants.fit_bspline_displacement_field(
473
                  displacement_origins=updated_fixed_points,
474
                  displacements=updated_moving_points - updated_fixed_points,
475
                  displacement_weights=displacement_weights,
476
                  origin=domain_image.origin,
477
                  spacing=domain_image.spacing,
478
                  size=domain_image.shape,
479
                  direction=domain_image.direction,
480
                  number_of_fitting_levels=number_of_fitting_levels,
481
                  mesh_size=mesh_size,
482
                  spline_order=spline_order,
483
                  enforce_stationary_boundary=True,
484
                  rasterize_points=rasterize_points
485
                  )
486
487
                if sigma > 0:
488
                    update_derivative_field_at_timepoint = ants.smooth_image(update_derivative_field_at_timepoint, sigma)
489
490
                update_derivative_field_at_timepoint_array = update_derivative_field_at_timepoint.numpy()
491
                grad_norms = np.sqrt(np.sum(np.square(update_derivative_field_at_timepoint_array), axis=-1, keepdims=False))
492
                max_norm = np.amax(grad_norms)
493
                median_norm = np.median(grad_norms)
494
                if verbose:
495
                    print("  integration point " + str(t) + ": max_norm = " + str(max_norm) + ", median_norm = " + str(median_norm))
496
                update_derivative_field_at_timepoint_array /= max_norm
497
                if domain_image.dimension == 2:
498
                    update_derivative_field_array[:,:,n,:] = update_derivative_field_at_timepoint_array
499
                elif domain_image.dimension == 3:
500
                    update_derivative_field_array[:,:,:,n,:] = update_derivative_field_at_timepoint_array
501
502
                rmse = np.mean(np.sqrt(np.sum(np.square(updated_moving_points - updated_fixed_points), axis=1, keepdims=True)))
503
                average_error = (average_error * n + rmse) / (n + 1)
504
505
            update_derivative_field_array = (update_derivative_field_array + last_update_derivative_field_array) * 0.5
506
            last_update_derivative_field_array = np.empty_like(update_derivative_field_array)
507
            last_update_derivative_field_array[:] = update_derivative_field_array
508
509
            velocity_field_array = velocity_field_array + update_derivative_field_array * composition_step_size
510
            velocity_field = ants.from_numpy(velocity_field_array, origin=velocity_field.origin,
511
                                             spacing=velocity_field.spacing, direction=velocity_field.direction,
512
                                             has_components=True)
513
514
            error_values.append(average_error)
515
            convergence_value = convergence_monitoring(error_values)
516
            if verbose:
517
                end_time = time.time()
518
                diff_time = end_time - start_time
519
                print("Composition " + str(i) + ": error = " + str(error_values[-1]) +
520
                      " (convergence = " + str(convergence_value) + ", elapsed time = " + str(diff_time) + ")")
521
            if not convergence_value is None and convergence_value <= convergence_threshold:
522
                break
523
524
        forward_xfrm = ants.transform_from_displacement_field(ants.integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
525
        inverse_xfrm = ants.transform_from_displacement_field(ants.integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))
526
527
        if verbose:
528
            end_total_time = time.time()
529
            diff_total_time = end_total_time - start_total_time
530
            print("Total elapsed time = " + str(diff_total_time) + ".")
531
532
        return_dict = {'forward_transform': forward_xfrm,
533
                       'inverse_transform': inverse_xfrm,
534
                       'velocity_field': velocity_field}
535
        return(return_dict)
536
537
    else:
538
        raise ValueError("Unrecognized transform_type.")
539
540
541
def fit_time_varying_transform_to_point_sets(point_sets,
542
                                             time_points=None,
543
                                             initial_velocity_field=None,
544
                                             number_of_time_steps=None,
545
                                             domain_image=None,
546
                                             number_of_fitting_levels=4,
547
                                             mesh_size=1,
548
                                             spline_order=3,
549
                                             displacement_weights=None,
550
                                             number_of_compositions=10,
551
                                             composition_step_size=0.5,
552
                                             number_of_integration_steps=100,
553
                                             sigma=0.0,
554
                                             convergence_threshold=1e-6,
555
                                             rasterize_points=False,
556
                                             verbose=False
557
                                            ):
558
    """
559
560
    Estimate a time-varying transform from corresponding point sets (> 2).
561
562
    ANTsR function: fitTimeVaryingTransformToPointSets
563
564
    Arguments
565
    ---------
566
    point_sets : list of arrays
567
        Corresponding points across sets specified in physical space as a n x d matrix where n
568
        is the number of points and d is the dimensionality.
569
570
    time_points : array of ordered scalars between 0 and 1
571
        Set of scalar values, one for each point-set, designating its time position in the velocity
572
        flow.  If not set, it defaults to equal spacing between 0 and 1.
573
574
    initial_velocity_field : initial ANTs velocity field
575
        Optional velocity field for initializing optimization.  Overrides the number of integration
576
        points.
577
578
    number_of_time_steps : integer
579
        Time-varying velocity field parameter.  Needs to be equal to or greater than the number of
580
        point sets.  If not specified, it defaults to the number of point sets.
581
582
    domain_image : ANTs image
583
        Defines physical domain of the nonlinear transform.  Must be defined.
584
585
    number_of_fitting_levels : integer
586
        Integer specifying the number of fitting levels for the B-spline interpolation of the
587
        displacement field.
588
589
    mesh_size : integer or array
590
        Defines the mesh size at the initial fitting level for the B-spline interpolation of the
591
        displacement field..
592
593
    spline_order : integer
594
        Spline order of the B-spline displacement field.
595
596
    displacement_weights : array
597
        Defines the individual weighting of the corresponding scattered data value.  Default = NULL
598
        meaning all displacements are weighted the same.
599
600
    number_of_compositions : integer
601
        Total number of compositions.
602
603
    composition_step_size : scalar
604
        Scalar multiplication factor of the weighting of the update field.
605
606
    number_of_integration_steps : scalar
607
        Number of steps used for integrating the velocity field.
608
609
    sigma : scalar
610
        Gaussian smoothing standard deviation of the update field (in mm).
611
612
    convergence_threshold : scalar
613
        Composition-based convergence parameter using a window size of 10 values.
614
615
    rasterize_points : boolean
616
        Use nearest neighbor rasterization of points for estimating the update field (potential
617
        speed-up).  Default = False.
618
619
    verbose : bool
620
        Print progress to the screen.
621
622
    Returns
623
    -------
624
625
    ANTs transform
626
627
    Example
628
    -------
629
    >>> import ants
630
    >>> import numpy as np
631
    """
632
633
    def create_zero_velocity_field(domain_image, number_of_time_points=2):
634
         field_array = np.zeros((*domain_image.shape, number_of_time_points, domain_image.dimension))
635
         origin = (*domain_image.origin, 0.0)
636
         spacing = (*domain_image.spacing, 1.0)
637
         direction = np.eye(domain_image.dimension + 1)
638
         direction[0:domain_image.dimension,0:domain_image.dimension] = domain_image.direction
639
         field = ants.from_numpy(field_array, origin=origin, spacing=spacing, direction=direction,
640
                 has_components=True)
641
         return(field)
642
643
    if not isinstance(point_sets, list):
644
        raise ValueError("point_sets should be a list of corresponding point sets.")
645
646
    number_of_point_sets = len(point_sets)
647
648
    if time_points is not None and len(time_points) != number_of_point_sets:
649
        raise ValueError("The number of time points should be the same as the number of point sets.")
650
651
    if time_points is None:
652
        time_points = np.linspace(0.0, 1.0, number_of_point_sets)
653
    time_points = np.array(time_points)
654
655
    if np.any(time_points < 0.0) or np.any(time_points > 1.0):
656
        raise ValueError("time point values should be between 0 and 1.")
657
658
    if number_of_point_sets < 3:
659
        raise ValueError("Expecting three or greater point sets.")
660
661
    if domain_image is None:
662
        raise ValueError("Domain image needs to be specified.")
663
664
    number_of_points = point_sets[0].shape[0]
665
    dimensionality = point_sets[0].shape[1]
666
    for i in range(1, number_of_point_sets):
667
        if point_sets[i].shape[0] != number_of_points:
668
            raise ValueError("Point sets should match in terms of the number of points.")
669
        if point_sets[i].shape[1] != dimensionality:
670
            raise ValueError("Point sets should match in terms of dimensionality.")
671
672
    if verbose:
673
        start_total_time = time.time()
674
675
    updated_fixed_points = np.zeros(point_sets[0].shape)
676
    updated_moving_points = np.zeros(point_sets[0].shape)
677
678
    velocity_field = None
679
    if initial_velocity_field is None:
680
        if number_of_time_steps is None:
681
            number_of_time_steps = len(time_points)
682
        if number_of_time_steps < number_of_point_sets:
683
            raise ValueError("The number of integration points should be at least as great as the number of point sets.")
684
        velocity_field = create_zero_velocity_field(domain_image, number_of_time_steps)
685
    else:
686
        velocity_field = ants.image_clone(initial_velocity_field)
687
        number_of_time_steps = initial_velocity_field.shape[-1]
688
    velocity_field_array = velocity_field.numpy()
689
690
    last_update_derivative_field = create_zero_velocity_field(domain_image, number_of_time_steps)
691
    last_update_derivative_field_array = last_update_derivative_field.numpy()
692
693
    error_values = []
694
    for i in range(number_of_compositions):
695
696
        if verbose:
697
            start_time = time.time()
698
699
        update_derivative_field = create_zero_velocity_field(domain_image, number_of_time_steps)
700
        update_derivative_field_array = update_derivative_field.numpy()
701
702
        average_error = 0.0
703
        for n in range(number_of_time_steps):
704
705
            t = n / (number_of_time_steps - 1.0)
706
707
            t_index = 0
708
            for j in range(1, number_of_point_sets):
709
                if time_points[j-1] <= t and time_points[j] >= t:
710
                    t_index = j
711
                    break
712
713
            if n > 0 and n < number_of_time_steps - 1 and time_points[t_index-1] == t:
714
                updated_fixed_points[:] = point_sets[t_index-1]
715
                integrated_inverse_field = ants.integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
716
                integrated_inverse_field_xfrm = ants.transform_from_displacement_field(integrated_inverse_field)
717
                for j in range(updated_moving_points.shape[0]):
718
                    updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
719
720
                update_derivative_field_at_timepoint_forward = ants.fit_bspline_displacement_field(
721
                  displacement_origins=updated_fixed_points,
722
                  displacements=updated_moving_points - updated_fixed_points,
723
                  displacement_weights=displacement_weights,
724
                  origin=domain_image.origin,
725
                  spacing=domain_image.spacing,
726
                  size=domain_image.shape,
727
                  direction=domain_image.direction,
728
                  number_of_fitting_levels=number_of_fitting_levels,
729
                  mesh_size=mesh_size,
730
                  spline_order=spline_order,
731
                  enforce_stationary_boundary=True,
732
                  rasterize_points=rasterize_points
733
                  )
734
735
                updated_moving_points[:] = point_sets[t_index-1]
736
                integrated_forward_field = ants.integrate_velocity_field(velocity_field, time_points[t_index-2], t, number_of_integration_steps)
737
                integrated_forward_field_xfrm = ants.transform_from_displacement_field(integrated_forward_field)
738
                for j in range(updated_fixed_points.shape[0]):
739
                    updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-2][j,:]))
740
741
                update_derivative_field_at_timepoint_back = ants.fit_bspline_displacement_field(
742
                  displacement_origins=updated_fixed_points,
743
                  displacements=updated_moving_points - updated_fixed_points,
744
                  displacement_weights=displacement_weights,
745
                  origin=domain_image.origin,
746
                  spacing=domain_image.spacing,
747
                  size=domain_image.shape,
748
                  direction=domain_image.direction,
749
                  number_of_fitting_levels=number_of_fitting_levels,
750
                  mesh_size=mesh_size,
751
                  spline_order=spline_order,
752
                  enforce_stationary_boundary=True,
753
                  rasterize_points=rasterize_points
754
                  )
755
756
                update_derivative_field_at_timepoint = (update_derivative_field_at_timepoint_forward +
757
                                                        update_derivative_field_at_timepoint_back) / 2.0
758
759
            else:
760
                if t == 0.0 and time_points[t_index-1] == 0.0:
761
                    updated_fixed_points[:] = point_sets[0]
762
                else:
763
                    integrated_forward_field = ants.integrate_velocity_field(velocity_field, time_points[t_index-1], t, number_of_integration_steps)
764
                    integrated_forward_field_xfrm = ants.transform_from_displacement_field(integrated_forward_field)
765
                    for j in range(updated_fixed_points.shape[0]):
766
                        updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-1][j,:]))
767
768
                if t == 1.0 and time_points[t_index] == 1.0:
769
                    updated_moving_points[:] = point_sets[-1]
770
                else:
771
                    integrated_inverse_field = ants.integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
772
                    integrated_inverse_field_xfrm = ants.transform_from_displacement_field(integrated_inverse_field)
773
                    for j in range(updated_moving_points.shape[0]):
774
                        updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
775
776
                update_derivative_field_at_timepoint = ants.fit_bspline_displacement_field(
777
                  displacement_origins=updated_fixed_points,
778
                  displacements=updated_moving_points - updated_fixed_points,
779
                  displacement_weights=displacement_weights,
780
                  origin=domain_image.origin,
781
                  spacing=domain_image.spacing,
782
                  size=domain_image.shape,
783
                  direction=domain_image.direction,
784
                  number_of_fitting_levels=number_of_fitting_levels,
785
                  mesh_size=mesh_size,
786
                  spline_order=spline_order,
787
                  enforce_stationary_boundary=True,
788
                  rasterize_points=rasterize_points
789
                  )
790
791
            if sigma > 0:
792
                update_derivative_field_at_timepoint = ants.smooth_image(update_derivative_field_at_timepoint, sigma)
793
794
            update_derivative_field_at_timepoint_array = update_derivative_field_at_timepoint.numpy()
795
            grad_norms = np.sqrt(np.sum(np.square(update_derivative_field_at_timepoint_array), axis=-1, keepdims=False))
796
            max_norm = np.amax(grad_norms)
797
            median_norm = np.median(grad_norms)
798
            if verbose:
799
                print("  integration point " + str(t) + ": max_norm = " + str(max_norm) + ", median_norm = " + str(median_norm))
800
            update_derivative_field_at_timepoint_array /= max_norm
801
            if domain_image.dimension == 2:
802
                update_derivative_field_array[:,:,n,:] = update_derivative_field_at_timepoint_array
803
            elif domain_image.dimension == 3:
804
                update_derivative_field_array[:,:,:,n,:] = update_derivative_field_at_timepoint_array
805
806
            rmse = np.mean(np.sqrt(np.sum(np.square(updated_moving_points - updated_fixed_points), axis=1, keepdims=True)))
807
            average_error = (average_error * n + rmse) / (n + 1)
808
809
        update_derivative_field_array = (update_derivative_field_array + last_update_derivative_field_array) * 0.5
810
        last_update_derivative_field_array = np.empty_like(update_derivative_field_array)
811
        last_update_derivative_field_array[:] = update_derivative_field_array
812
813
        velocity_field_array += (update_derivative_field_array * composition_step_size)
814
        velocity_field = ants.from_numpy(velocity_field_array, origin=velocity_field.origin,
815
                                         spacing=velocity_field.spacing, direction=velocity_field.direction,
816
                                         has_components=True)
817
818
        error_values.append(average_error)
819
        convergence_value = convergence_monitoring(error_values)
820
        if verbose:
821
            end_time = time.time()
822
            diff_time = end_time - start_time
823
            print("Composition " + str(i) + ": error = " + str(error_values[-1]) +
824
                  " (convergence = " + str(convergence_value) + ", elapsed time = " + str(diff_time) + ")")
825
        if not convergence_value is None and convergence_value <= convergence_threshold:
826
            break
827
828
    forward_xfrm = ants.transform_from_displacement_field(ants.integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
829
    inverse_xfrm = ants.transform_from_displacement_field(ants.integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))
830
831
    if verbose:
832
        end_total_time = time.time()
833
        diff_total_time = end_total_time - start_total_time
834
        print("Total elapsed time = " + str(diff_total_time) + ".")
835
836
    return_dict = {'forward_transform': forward_xfrm,
837
                   'inverse_transform': inverse_xfrm,
838
                   'velocity_field': velocity_field}
839
    return(return_dict)
840
841
842
843