"""
Tests for the ants.registration module (tests/test_registration.py).

Assertion helpers:
    nptest.assert_allclose
    self.assertEqual
    self.assertTrue
"""

import os
10
import unittest
11
from common import run_tests
12
from tempfile import mktemp
13
14
import numpy as np
15
import numpy.testing as nptest
16
import pandas as pd
17
18
import ants
19
20
21
class TestModule_affine_initializer(unittest.TestCase):
    """Smoke test for ``ants.affine_initializer``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        # Mirrors the ANTsPy/ANTsR documentation example.
        fixed_img = ants.image_read(ants.get_ants_data("r16"))
        moving_img = ants.image_read(ants.get_ants_data("r27"))
        init_tx_file = ants.affine_initializer(fixed_img, moving_img)
        # The initializer's output must be loadable by read_transform.
        ants.read_transform(init_tx_file)


class TestModule_apply_transforms(unittest.TestCase):
    """Tests for ``ants.apply_transforms``: output type/space and error paths."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        # test ANTsPy/ANTsR example
        fixed = ants.image_read(ants.get_ants_data("r16"))
        moving = ants.image_read(ants.get_ants_data("r64"))
        fixed = ants.resample_image(fixed, (64, 64), 1, 0)
        moving = ants.resample_image(moving, (128, 128), 1, 0)
        mytx = ants.registration(fixed=fixed, moving=moving, type_of_transform="SyN")
        mywarpedimage = ants.apply_transforms(
            fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"]
        )
        # The warped output keeps the moving image's pixel type but must match
        # the fixed image's physical space.
        self.assertEqual(mywarpedimage.pixeltype, moving.pixeltype)
        self.assertTrue(
            ants.image_physical_space_consistency(
                fixed, mywarpedimage, 0.0001, datatype=False
            )
        )

        # Call with float precision for transforms, but should still return input type
        mywarpedimage2 = ants.apply_transforms(
            fixed=fixed,
            moving=moving,
            transformlist=mytx["fwdtransforms"],
            singleprecision=True,
        )
        self.assertEqual(mywarpedimage2.pixeltype, moving.pixeltype)
        self.assertLessEqual(
            np.sum((mywarpedimage.numpy() - mywarpedimage2.numpy()) ** 2), 0.1
        )

        # bad interpolator
        with self.assertRaises(Exception):
            ants.apply_transforms(
                fixed=fixed,
                moving=moving,
                transformlist=mytx["fwdtransforms"],
                interpolator="unsupported-interp",
            )

        # transform doesn't exist: keep the default (valid) interpolator so the
        # missing transform file is the only thing that can trigger the error.
        with self.assertRaises(Exception):
            ants.apply_transforms(
                fixed=fixed,
                moving=moving,
                transformlist=["blah-blah.mat"],
            )


class TestModule_create_jacobian_determinant_image(unittest.TestCase):
    """Tests for ``ants.create_jacobian_determinant_image``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        fi = ants.image_read(ants.get_ants_data("r16"))
        mi = ants.image_read(ants.get_ants_data("r64"))
        fi = ants.resample_image(fi, (128, 128), 1, 0)
        mi = ants.resample_image(mi, (128, 128), 1, 0)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform="SyN")
        # Let any error propagate: swallowing it would make this test pass even
        # when the Jacobian computation is broken.
        jac = ants.create_jacobian_determinant_image(fi, mytx["fwdtransforms"][0], 1)
        # The Jacobian is expected on the domain image's (fi's) grid.
        self.assertEqual(jac.shape, fi.shape)


class TestModule_create_warped_grid(unittest.TestCase):
    """Smoke test for ``ants.create_warped_grid`` with and without a transform."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        fixed_img = ants.image_read(ants.get_ants_data("r16"))
        moving_img = ants.image_read(ants.get_ants_data("r64"))

        # Grid over the moving image with no transform supplied.
        ants.create_warped_grid(moving_img)

        reg = ants.registration(
            fixed=fixed_img, moving=moving_img, type_of_transform="SyN"
        )
        # Only the second grid direction enabled, warped by the forward
        # transforms with the fixed image as reference.
        ants.create_warped_grid(
            moving_img,
            grid_directions=(False, True),
            transform=reg["fwdtransforms"],
            fixed_reference_image=fixed_img,
        )


class TestModule_fsl2antstransform(unittest.TestCase):
    """Smoke test for ``ants.fsl2antstransform``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        # A 4x4 identity matrix passed as the FSL matrix.
        identity = np.eye(4)
        ch2 = ants.image_read(ants.get_ants_data("ch2"))
        # Use the same image as both reference and moving space.
        ants.fsl2antstransform(identity, ch2, ch2)


class TestModule_interface(unittest.TestCase):
    """Exercise ``ants.registration`` across its ``type_of_transform`` options."""

    def setUp(self):
        # A tuple rather than a set, so the transforms always run in the same
        # order. Set iteration order over strings changes with hash
        # randomization, which makes failures hard to reproduce.
        self.transform_types = (
            "SyNBold",
            "SyNBoldAff",
            "ElasticSyN",
            "SyN",
            "SyNRA",
            "SyNOnly",
            "SyNAggro",
            "SyNCC",
            "TRSAA",
            "SyNabp",
            "SyNLessAggro",
            "TVMSQ",
            "TVMSQC",
            "Rigid",
            "Similarity",
            "Translation",
            "Affine",
            "AffineFast",
            "BOLDAffine",
            "QuickRigid",
            "DenseRigid",
            "BOLDRigid",
            "antsRegistrationSyNQuick[b,32,26]",
            "antsRegistrationSyNQuick[s]",
            "antsRegistrationSyNRepro[s]",
            "antsRegistrationSyN[s]",
        )

    def tearDown(self):
        pass

    def test_example(self):
        fi = ants.image_read(ants.get_ants_data("r16"))
        mi = ants.image_read(ants.get_ants_data("r64"))
        fi = ants.resample_image(fi, (60, 60), 1, 0)
        mi = ants.resample_image(mi, (60, 60), 1, 0)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform="SyN")
        self.assertIn("fwdtransforms", mytx)

    def test_affine_interface(self):
        print("Starting affine interface registration test")
        fi = ants.image_read(ants.get_ants_data("r16"))
        mi = ants.image_read(ants.get_ants_data("r64"))
        # Mismatched lengths of the affine schedule parameters must be rejected.
        with self.assertRaises(ValueError):
            ants.registration(
                fixed=fi,
                moving=mi,
                type_of_transform="Translation",
                aff_iterations=4,
                aff_shrink_factors=4,
                aff_smoothing_sigmas=(4, 4),
            )

        # Consistent tuple schedules are accepted...
        ants.registration(
            fixed=fi,
            moving=mi,
            type_of_transform="Affine",
            aff_iterations=(4, 4),
            aff_shrink_factors=(4, 4),
            aff_smoothing_sigmas=(4, 4),
        )
        # ...as are consistent scalar schedules.
        ants.registration(
            fixed=fi,
            moving=mi,
            type_of_transform="Translation",
            aff_iterations=4,
            aff_shrink_factors=4,
            aff_smoothing_sigmas=4,
        )

    def test_registration_types(self):
        print("Starting long registration interface test")
        fi = ants.image_read(ants.get_ants_data("r16"))
        mi = ants.image_read(ants.get_ants_data("r64"))
        fi = ants.resample_image(fi, (60, 60), 1, 0)
        mi = ants.resample_image(mi, (60, 60), 1, 0)
        # The mask depends only on fi, so build it once outside the loop.
        fimask = fi > fi.mean()

        for ttype in self.transform_types:
            # subTest labels any failure with the transform type being run.
            with self.subTest(type_of_transform=ttype):
                print(ttype)
                ants.registration(fixed=fi, moving=mi, type_of_transform=ttype)

                # with mask
                ants.registration(
                    fixed=fi, moving=mi, mask=fimask, type_of_transform=ttype
                )
        print("Finished long registration interface test")


class TestModule_metrics(unittest.TestCase):
    """Tests for ``ants.image_mutual_information``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        fi = ants.image_read(ants.get_ants_data("r16")).clone("float")
        mi = ants.image_read(ants.get_ants_data("r64")).clone("float")
        mival = ants.image_mutual_information(fi, mi)
        # The example's reference value is about -0.1796141. Only finiteness is
        # pinned, since the exact value may vary between builds.
        self.assertTrue(np.isfinite(mival))


class TestModule_reflect_image(unittest.TestCase):
    """Smoke test for ``ants.reflect_image`` with an affine refinement."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        source = ants.image_read(ants.get_ants_data("r16"))
        reflect_axis = 2
        reflected = ants.reflect_image(source, reflect_axis, "Affine")["warpedmovout"]
        # Voxel-wise difference between the reflected and original image.
        asymmetry = reflected - source


class TestModule_reorient_image(unittest.TestCase):
    """Tests for image reorientation and ``ants.get_center_of_mass``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_reorient_image(self):
        template = ants.image_read(ants.get_data('mni'))
        template.reorient_image2()

    def _check_com_length(self, image):
        # One center-of-mass coordinate is expected per image dimension.
        centre = ants.get_center_of_mass(image)
        self.assertEqual(len(centre), image.dimension)

    def test_get_center_of_mass(self):
        r16 = ants.image_read(ants.get_ants_data("r16"))
        self._check_com_length(r16)

        r64 = ants.image_read(ants.get_ants_data("r64"))
        self._check_com_length(r64)
        # Same image with an unsigned integer pixel type.
        self._check_com_length(r64.clone("unsigned int"))

        # 3d
        mni = ants.image_read(ants.get_ants_data("mni"))
        self._check_com_length(mni)


class TestModule_resample_image(unittest.TestCase):
    """Tests for ``ants.resample_image`` and ``ants.resample_image_to_target``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_resample_image_example(self):
        r16 = ants.image_read(ants.get_ants_data("r16"))
        # Resample to a voxel count, then to a spacing.
        by_voxels = ants.resample_image(r16, (50, 60), use_voxels=True, interp_type=0)
        by_spacing = ants.resample_image(r16, (1.5, 1.5), use_voxels=False, interp_type=1)

    def test_resample_channels(self):
        gray = ants.image_read(ants.get_ants_data("r16"))
        two_channel = ants.merge_channels([gray, gray])
        resized = ants.resample_image(two_channel, (128, 128), use_voxels=True)
        # Both the grid size and the channel count must survive resampling.
        self.assertEqual(resized.shape, (128, 128))
        self.assertEqual(resized.components, 2)

    def test_resample_image_to_target_example(self):
        r16 = ants.image_read(ants.get_ants_data("r16"))
        coarse = ants.resample_image(r16, (2, 2), use_voxels=0, interp_type=1)
        restored = ants.resample_image_to_target(coarse, r16, verbose=True)
        self.assertTrue(
            ants.image_physical_space_consistency(r16, restored, 0.0001, datatype=True)
        )


class TestModule_symmetrize_image(unittest.TestCase):
    """Smoke test for ``ants.symmetrize_image``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        source = ants.image_read(ants.get_ants_data("r16"))
        symmetric = ants.symmetrize_image(source)


class TestModule_build_template(unittest.TestCase):
    """Tests for ``ants.build_template``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        image = ants.image_read(ants.get_ants_data("r16"))
        image2 = ants.image_read(ants.get_ants_data("r27"))
        ants.build_template(image_list=(image, image2))

    def test_type_of_transform(self):
        image = ants.image_read(ants.get_ants_data("r16"))
        image2 = ants.image_read(ants.get_ants_data("r27"))
        # test_example already covers the default-transform build, so only the
        # explicit type_of_transform path runs here, avoiding a duplicate build.
        ants.build_template(image_list=(image, image2), type_of_transform="SyNCC")


class TestModule_multivar(unittest.TestCase):
    """Registration with extra metrics via ``multivariate_extras``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_example(self):
        img_a = ants.image_read(ants.get_ants_data("r16"))
        img_b = ants.image_read(ants.get_ants_data("r27"))
        demons_metric = ["demons", img_a, img_b, 1, 1]
        cc_metric = ["CC", img_a, img_b, 2, 1]

        # First run with one extra metric, then grow the same list to two.
        extras = [demons_metric]
        ants.registration(img_a, img_b, "SyNOnly", multivariate_extras=extras)
        extras.append(cc_metric)
        ants.registration(
            img_a, img_b, "SyNOnly", multivariate_extras=extras, verbose=True
        )

class TestModule_random(unittest.TestCase):
    """Assorted end-to-end checks.

    Covers landmark fitting, deformation gradients, Jacobians, transform
    application, warped grids, motion correction, label-image registration
    and registration precision.
    """

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_landmark_transforms(self):
        fixed = np.array([[50.0, 50.0], [200.0, 50.0], [200.0, 200.0]])
        moving = np.array([[50.0, 50.0], [50.0, 200.0], [200.0, 200.0]])
        # Read the r16 domain image once and share it with every fit that takes
        # a domain image, instead of re-reading the same file for each call.
        domain_image = ants.image_read(ants.get_ants_data("r16"))

        ants.fit_transform_to_paired_points(
            moving, fixed, transform_type="syn", domain_image=domain_image, verbose=True
        )
        ants.fit_transform_to_paired_points(
            moving, fixed, transform_type="tv", domain_image=domain_image
        )
        # These transform types are fit without a domain image.
        for ttype in ("affine", "rigid", "similarity"):
            with self.subTest(transform_type=ttype):
                ants.fit_transform_to_paired_points(moving, fixed, transform_type=ttype)
        ants.fit_transform_to_paired_points(
            moving,
            fixed,
            transform_type="bspline",
            domain_image=domain_image,
            number_of_fitting_levels=5,
        )
        ants.fit_transform_to_paired_points(
            moving,
            fixed,
            transform_type="diffeo",
            domain_image=domain_image,
            number_of_fitting_levels=6,
        )

        ants.fit_time_varying_transform_to_point_sets(
            [fixed, moving, moving], domain_image=domain_image, verbose=True
        )

    def test_deformation_gradient(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        fi = ants.resample_image(fi, (128, 128), 1, 0)
        mi = ants.resample_image(mi, (128, 128), 1, 0)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform='SyN')
        # Read the forward warp once; every variant below uses the same field.
        warp = ants.image_read(mytx['fwdtransforms'][0])

        ants.deformation_gradient(warp)
        ants.deformation_gradient(warp, py_based=True)
        ants.deformation_gradient(warp, to_rotation=True)
        ants.deformation_gradient(warp, to_rotation=True, py_based=True)

    def test_jacobian(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        fi = ants.resample_image(fi, (128, 128), 1, 0)
        mi = ants.resample_image(mi, (128, 128), 1, 0)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform='SyN')
        jac = ants.create_jacobian_determinant_image(fi, mytx['fwdtransforms'][0], 1)
        # The Jacobian is expected on the domain image's (fi's) grid.
        self.assertEqual(jac.shape, fi.shape)

    def test_apply_transforms(self):
        fixed = ants.image_read(ants.get_ants_data('r16'))
        moving = ants.image_read(ants.get_ants_data('r64'))
        fixed = ants.resample_image(fixed, (64, 64), 1, 0)
        moving = ants.resample_image(moving, (64, 64), 1, 0)
        mytx = ants.registration(fixed=fixed, moving=moving, type_of_transform='SyN')
        ants.apply_transforms(
            fixed=fixed, moving=moving, transformlist=mytx['fwdtransforms']
        )

    def test_apply_transforms_to_points(self):
        fixed = ants.image_read(ants.get_ants_data('r16'))
        moving = ants.image_read(ants.get_ants_data('r27'))
        reg = ants.registration(fixed, moving, 'Affine')
        d = {'x': [128, 127], 'y': [101, 111]}
        pts = pd.DataFrame(data=d)
        ants.apply_transforms_to_points(2, pts, reg['fwdtransforms'])

    def test_warped_grid(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        ants.create_warped_grid(mi)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform='SyN')
        ants.create_warped_grid(
            mi,
            grid_directions=(False, True),
            transform=mytx['fwdtransforms'],
            fixed_reference_image=fi,
        )

    def test_more_registration(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        fi = ants.resample_image(fi, (60, 60), 1, 0)
        mi = ants.resample_image(mi, (60, 60), 1, 0)
        for ttype in (
            'SyN',
            'antsRegistrationSyN[t]',
            'antsRegistrationSyN[b]',
            'antsRegistrationSyN[s]',
        ):
            # subTest labels any failure with the transform type being run.
            with self.subTest(type_of_transform=ttype):
                ants.registration(fixed=fi, moving=mi, type_of_transform=ttype)

    def test_motion_correction(self):
        fi = ants.image_read(ants.get_ants_data('ch2'))
        ants.motion_correction(fi)

    def test_label_image_registration(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        fi = ants.resample_image(fi, (60, 60), 1, 0)
        mi = ants.resample_image(mi, (60, 60), 1, 0)
        # Three-class k-means segmentations, shifted down by one.
        fi_seg = ants.threshold_image(fi, "Kmeans", 3) - 1
        mi_seg = ants.threshold_image(mi, "Kmeans", 3) - 1
        ants.label_image_registration(
            [fi_seg],
            [mi_seg],
            fixed_intensity_images=fi,
            moving_intensity_images=mi,
        )

    def test_reg_precision_option(self):
        # Check that registration and apply transforms works with float and double precision
        fi = ants.image_read(ants.get_ants_data("r16"))
        mi = ants.image_read(ants.get_ants_data("r64"))
        fi = ants.resample_image(fi, (60, 60), 1, 0)
        mi = ants.resample_image(mi, (60, 60), 1, 0)
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform="SyN")  # should be float precision
        info = ants.image_header_info(mytx["fwdtransforms"][0])
        self.assertEqual(info['pixeltype'], 'float')
        mytx = ants.registration(fixed=fi, moving=mi, type_of_transform="SyN", singleprecision=False)
        info = ants.image_header_info(mytx["fwdtransforms"][0])
        self.assertEqual(info['pixeltype'], 'double')

if __name__ == "__main__":
    # When run directly, hand control to the shared runner imported from `common`.
    run_tests()