a b/ants/contrib/sampling/transforms.py
1
"""
2
Various data augmentation transforms for ANTsImage types
3
4
List of Transformations:
5
======================
6
- CastIntensity
7
- BlurIntensity
8
- NormalizeIntensity
9
- RescaleIntensity
10
- ShiftScaleIntensity
11
- SigmoidIntensity
12
======================
13
- FlipImage
14
- TranslateImage
15
16
TODO
17
----
18
- RotateImage
19
- ShearImage
20
- ScaleImage
21
- DeformImage
22
- PadImage
23
- HistogramEqualizeIntensity
24
- TruncateIntensity
25
- SharpenIntensity
26
- MorpholigicalIntensity
27
    - MD
28
    - ME
29
    - MO
30
    - MC
31
    - GD
32
    - GE
33
    - GO
34
    - GC
35
"""
36
# Public API of this module: the transform classes defined below.
__all__ = [
    'CastIntensity',
    'BlurIntensity',
    'LocallyBlurIntensity',
    'NormalizeIntensity',
    'RescaleIntensity',
    'ShiftScaleIntensity',
    'SigmoidIntensity',
    'FlipImage',
    'ScaleImage',
    'TranslateImage',
    'MultiResolutionImage',
]
47
48
from ... import utils
49
from ...core import ants_image as iio
50
51
52
class MultiResolutionImage(object):
    """
    Generate a set of images at multiple resolutions from an original image
    """
    def __init__(self, levels=4, keep_shape=False):
        """
        Initialize a MultiResolutionImage transform

        Arguments
        ---------
        levels : int
            number of resolution levels; forwarded unchanged to the
            compiled multi-resolution routine

        keep_shape : boolean
            if True, every generated image is resampled onto the input
            image before being returned

        Example
        -------
        >>> import ants
        >>> multires = ants.contrib.MultiResolutionImage(levels=4, keep_shape=True)
        """
        self.levels = levels
        self.keep_shape = keep_shape

    def transform(self, X, y=None):
        """
        Generate a set of multi-resolution ANTsImage types

        Arguments
        ---------
        X : ANTsImage
            image to transform

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        list of ANTsImage
            one image per pointer returned by the compiled routine,
            in the order the routine returned them

        Example
        -------
        >>> import ants
        >>> multires = ants.contrib.MultiResolutionImage(levels=4)
        >>> img = ants.image_read(ants.get_data('r16'))
        >>> imgs = multires.transform(img)
        """
        lib_name = 'multiResolutionAntsImage%s' % (X._libsuffix)
        level_ptrs = utils.get_lib_fn(lib_name)(X.pointer, self.levels)

        def _as_image(ptr):
            # Wrap a raw pointer using the input's pixeltype/dimension/components,
            # optionally resampling it back onto the input image.
            level_img = iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                                      components=X.components, pointer=ptr)
            if self.keep_shape:
                return level_img.resample_image_to_target(X)
            return level_img

        return [_as_image(ptr) for ptr in level_ptrs]
92
93
94
## Intensity Transforms ##
95
96
class CastIntensity(object):
    """
    Cast the pixeltype of an ANTsImage to a given type. 
    This code uses the C++ ITK library directly, so it is fast.
    
    NOTE: This offers a ~2.5x speedup over using img.clone(pixeltype):
    
    Timings vs Cloning
    ------------------
    >>> import ants
    >>> import time
    >>> caster = ants.contrib.CastIntensity('float')
    >>> img = ants.image_read(ants.get_data('mni')).clone('unsigned int')
    >>> s = time.time()
    >>> for i in range(1000):
    ...     img_float = caster.transform(img)
    >>> e = time.time()
    >>> print(e - s) # 9.6s
    >>> s = time.time()
    >>> for i in range(1000):
    ...     img_float = img.clone('float')
    >>> e = time.time()
    >>> print(e - s) # 25.3s
    """
    def __init__(self, pixeltype):
        """
        Initialize a CastIntensity transform

        Arguments
        ---------
        pixeltype : string
            pixeltype to which images will be cast

        Example
        -------
        >>> import ants
        >>> caster = ants.contrib.CastIntensity('float')
        """
        self.pixeltype = pixeltype

    def transform(self, X, y=None):
        """
        Return a copy of an image whose voxels are cast to `self.pixeltype`.

        Arguments
        ---------
        X : ANTsImage
            image to cast

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            image with pixeltype `self.pixeltype` and the same dimension
            and component count as X

        Example
        -------
        >>> import ants
        >>> caster = ants.contrib.CastIntensity('float')
        >>> img2d = ants.image_read(ants.get_data('r16')).clone('unsigned int')
        >>> img2d_float = caster.transform(img2d)
        >>> print(img2d.pixeltype, '- ', img2d_float.pixeltype)
        >>> img3d = ants.image_read(ants.get_data('mni')).clone('unsigned int')
        >>> img3d_float = caster.transform(img3d)
        >>> print(img3d.pixeltype, ' - ' , img3d_float.pixeltype)
        """
        source_suffix = X._libsuffix
        # The compiled cast is looked up as castAntsImage<source suffix><target suffix>,
        # where the target suffix is the short pixeltype code followed by the dimension.
        target_suffix = '%s%i' % (utils.short_ptype(self.pixeltype), X.dimension)
        cast_fn = utils.get_lib_fn('castAntsImage' + source_suffix + target_suffix)
        return iio.ANTsImage(pixeltype=self.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=cast_fn(X.pointer))
165
166
167
class BlurIntensity(object):
    """
    Transform for blurring the intensity of an ANTsImage
    using a Gaussian Filter
    """
    def __init__(self, sigma, width):
        """
        Initialize a BlurIntensity transform

        Arguments
        ---------
        sigma : float
            variance of the gaussian kernel; increasing this value
            increases the amount of blur

        width : int
            width of the gaussian kernel; increasing this value increases
            the number of neighboring voxels which are used for blurring

        Example
        -------
        >>> import ants
        >>> blur = ants.contrib.BlurIntensity(2,3)
        """
        self.sigma = sigma
        self.width = width

    def transform(self, X, y=None):
        """
        Blur an image by applying a gaussian filter.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            blurred image with the same pixeltype, dimension, components
            and origin as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> blur = ants.contrib.BlurIntensity(2,3)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_b = blur.transform(img2d)
        >>> ants.plot(img2d)
        >>> ants.plot(img2d_b)
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_b = blur.transform(img3d)
        >>> ants.plot(img3d)
        >>> ants.plot(img3d_b)
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        blur_fn = utils.get_lib_fn('blurAntsImage%s' % (insuffix))
        blurred_ptr = blur_fn(X.pointer, self.sigma, self.width)
        # NOTE(review): the origin is copied from the input explicitly; presumably
        # the compiled filter's output does not carry it over — confirm.
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=blurred_ptr,
                             origin=X.origin)
230
231
232
class LocallyBlurIntensity(object):
    """
    Blur an ANTsImage locally using a gradient anisotropic
    diffusion filter, thereby preserving the sharpness of edges as best 
    as possible.
    """
    def __init__(self, conductance=1, iters=5):
        """
        Initialize a LocallyBlurIntensity transform

        Arguments
        ---------
        conductance : float
            conductance parameter forwarded to the diffusion routine

        iters : int
            iteration count forwarded to the diffusion routine

        Example
        -------
        >>> import ants
        >>> blur = ants.contrib.LocallyBlurIntensity(1,5)
        """
        self.conductance = conductance
        self.iters = iters

    def transform(self, X, y=None):
        """
        Locally blur an image by applying a gradient anisotropic diffusion filter.

        Arguments
        ---------
        X : ANTsImage
            image to transform

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            smoothed image with the same pixeltype, dimension and
            components as X

        Example
        -------
        >>> import ants
        >>> blur = ants.contrib.LocallyBlurIntensity(1,5)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_b = blur.transform(img2d)
        >>> ants.plot(img2d)
        >>> ants.plot(img2d_b)
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_b = blur.transform(img3d)
        >>> ants.plot(img3d)
        >>> ants.plot(img3d_b)
        """
        # Unlike most transforms in this module, no float-pixeltype check is enforced here.
        diffuse = utils.get_lib_fn('locallyBlurAntsImage%s' % (X._libsuffix))
        # The compiled routine receives the iteration count before the conductance.
        smoothed_ptr = diffuse(X.pointer, self.iters, self.conductance)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=smoothed_ptr)
274
275
276
class NormalizeIntensity(object):
    """
    Normalize the intensity values of an ANTsImage to have
    zero mean and unit variance

    NOTE: this transform is more-or-less the same in speed
    as an equivalent numpy+scikit-learn solution.

    Timing vs Numpy+Scikit-Learn
    ----------------------------
    >>> import ants
    >>> import numpy as np
    >>> from sklearn.preprocessing import StandardScaler
    >>> import time
    >>> img = ants.image_read(ants.get_data('mni'))
    >>> arr = img.numpy().reshape(1,-1)
    >>> normalizer = ants.contrib.NormalizeIntensity()
    >>> normalizer2 = StandardScaler()
    >>> s = time.time()
    >>> for i in range(100):
    ...     img_scaled = normalizer.transform(img)
    >>> e = time.time()
    >>> print(e - s) # 3.3s
    >>> s = time.time()
    >>> for i in range(100):
    ...     arr_scaled = normalizer2.fit_transform(arr)
    >>> e = time.time()
    >>> print(e - s) # 3.5s
    """
    def __init__(self):
        """
        Initialize a NormalizeIntensity transform (it takes no parameters)
        """
        pass

    def transform(self, X, y=None):
        """
        Transform an image by normalizing its intensity values to 
        have zero mean and unit variance.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            normalized image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> normalizer = ants.contrib.NormalizeIntensity()
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = normalizer.transform(img2d)
        >>> print(img2d.mean(), ',', img2d.std(), ' -> ', img2d_r.mean(), ',', img2d_r.std())
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = normalizer.transform(img3d)
        >>> print(img3d.mean(), ',' , img3d.std(), ',', ' -> ', img3d_r.mean(), ',' , img3d_r.std())  
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        normalize_fn = utils.get_lib_fn('normalizeAntsImage%s' % (insuffix))
        normalized_ptr = normalize_fn(X.pointer)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=normalized_ptr)
343
344
345
class RescaleIntensity(object):
    """
    Rescale the intensity values of an ANTsImage linearly to be between
    a given minimum and maximum value. 
    This code uses the C++ ITK library directly, so it is fast.
    
    NOTE: this offered a ~5x speedup over using built-in arithmetic operations in ANTs.
    It is also more-or-less the same in speed as an equivalent numpy+scikit-learn
    solution.

    Timing vs Built-in Operations
    -----------------------------
    >>> import ants
    >>> import time
    >>> rescaler = ants.contrib.RescaleIntensity(0,1)
    >>> img = ants.image_read(ants.get_data('mni'))
    >>> s = time.time()
    >>> for i in range(100):
    ...     img_float = rescaler.transform(img)
    >>> e = time.time()
    >>> print(e - s) # 2.8s
    >>> s = time.time()
    >>> for i in range(100):
    ...     minval = img.min()
    ...     maxval = img.max()
    ...     img_float = (img - minval) / (maxval - minval)
    >>> e = time.time()
    >>> print(e - s) # 13.9s

    Timing vs Numpy+Scikit-Learn
    ----------------------------
    >>> import ants
    >>> import numpy as np
    >>> from sklearn.preprocessing import MinMaxScaler
    >>> import time
    >>> img = ants.image_read(ants.get_data('mni'))
    >>> arr = img.numpy().reshape(1,-1)
    >>> rescaler = ants.contrib.RescaleIntensity(-1,1)
    >>> rescaler2 = MinMaxScaler((-1,1)).fit(arr)
    >>> s = time.time()
    >>> for i in range(100):
    ...     img_scaled = rescaler.transform(img)
    >>> e = time.time()
    >>> print(e - s) # 2.8s
    >>> s = time.time()
    >>> for i in range(100):
    ...     arr_scaled = rescaler2.transform(arr)
    >>> e = time.time()
    >>> print(e - s) # 3s
    """

    def __init__(self, min_val, max_val):
        """
        Initialize a RescaleIntensity transform.

        Arguments
        ---------
        min_val : float
            minimum value to which image(s) will be rescaled

        max_val : float
            maximum value to which image(s) will be rescaled
        
        Example
        -------
        >>> import ants
        >>> rescaler = ants.contrib.RescaleIntensity(0,1)
        """
        self.min_val = min_val
        self.max_val = max_val

    def transform(self, X, y=None):
        """
        Transform an image by linearly rescaling its intensity to
        be between a minimum and maximum value

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            rescaled image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> rescaler = ants.contrib.RescaleIntensity(0,1)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = rescaler.transform(img2d)
        >>> print(img2d.min(), ',', img2d.max(), ' -> ', img2d_r.min(), ',', img2d_r.max())
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = rescaler.transform(img3d)
        >>> print(img3d.min(), ',' , img3d.max(), ' -> ', img3d_r.min(), ',' , img3d_r.max())
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        rescale_fn = utils.get_lib_fn('rescaleAntsImage%s' % (insuffix))
        rescaled_ptr = rescale_fn(X.pointer, self.min_val, self.max_val)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=rescaled_ptr)
447
448
449
class ShiftScaleIntensity(object):
    """
    Shift and scale the intensity of an ANTsImage
    """
    def __init__(self, shift, scale):
        """
        Initialize a ShiftScaleIntensity transform

        Arguments
        ---------
        shift : float
            shift all of the intensity values by the given amount through addition.
            For example, if the minimum image value is 0.0 and the shift
            is 10.0, then the new minimum value (before scaling) will be 10.0

        scale : float
            scale all the intensity values by the given amount through multiplication.
            For example, if the min/max image values are 10/20 and the scale
            is 2.0, then the new min/max values will be 20/40

        Example
        -------
        >>> import ants
        >>> shiftscaler = ants.contrib.ShiftScaleIntensity(shift=10, scale=2)
        """
        self.shift = shift
        self.scale = scale

    def transform(self, X, y=None):
        """
        Transform an image by shifting and scaling its intensity values.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            shifted/scaled image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> shiftscaler = ants.contrib.ShiftScaleIntensity(10,2.)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = shiftscaler.transform(img2d)
        >>> print(img2d.min(), ',', img2d.max(), ' -> ', img2d_r.min(), ',', img2d_r.max())
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = shiftscaler.transform(img3d)
        >>> print(img3d.min(), ',' , img3d.max(), ',', ' -> ', img3d_r.min(), ',' , img3d_r.max())  
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        shift_scale_fn = utils.get_lib_fn('shiftScaleAntsImage%s' % (insuffix))
        # NOTE(review): arguments are passed as (scale, shift) — the reverse of the
        # constructor's order; presumably this matches shiftScaleAntsImage, verify there.
        shifted_ptr = shift_scale_fn(X.pointer, self.scale, self.shift)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=shifted_ptr)
508
509
510
class SigmoidIntensity(object):
    """
    Transform an image using a sigmoid function
    """
    def __init__(self, min_val, max_val, alpha, beta):
        """
        Initialize a SigmoidIntensity transform

        Arguments
        ---------
        min_val : float
            minimum value forwarded to the compiled sigmoid routine
            (presumably the minimum of the output range — confirm)

        max_val : float
            maximum value forwarded to the compiled sigmoid routine
            (presumably the maximum of the output range — confirm)

        alpha : float
            alpha value for sigmoid

        beta : float
            beta value for sigmoid

        Example
        -------
        >>> import ants
        >>> sigscaler = ants.contrib.SigmoidIntensity(0,1,1,1)
        """
        self.min_val = min_val
        self.max_val = max_val
        self.alpha = alpha
        self.beta = beta

    def transform(self, X, y=None):
        """
        Transform an image by applying a sigmoid function.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            transformed image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> sigscaler = ants.contrib.SigmoidIntensity(0,1,1,1)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = sigscaler.transform(img2d)
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = sigscaler.transform(img3d)
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        sigmoid_fn = utils.get_lib_fn('sigmoidAntsImage%s' % (insuffix))
        sigmoid_ptr = sigmoid_fn(X.pointer, self.min_val, self.max_val, self.alpha, self.beta)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=sigmoid_ptr)
571
572
573
## Physical Transforms ##
574
575
class FlipImage(object):
    """
    Transform an image by flipping two axes. 
    """
    def __init__(self, axis1, axis2):
        """
        Initialize a FlipImage transform

        Arguments
        ---------
        axis1 : int
            axis to flip

        axis2 : int
            other axis to flip

        Example
        -------
        >>> import ants
        >>> flipper = ants.contrib.FlipImage(0,1)
        """
        self.axis1 = axis1
        self.axis2 = axis2

    def transform(self, X, y=None):
        """
        Transform an image by flipping it along the configured axes.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            flipped image with the same pixeltype, dimension, components
            and origin as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float'

        Example
        -------
        >>> import ants
        >>> flipper = ants.contrib.FlipImage(0,1)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = flipper.transform(img2d)
        >>> ants.plot(img2d)
        >>> ants.plot(img2d_r)
        >>> flipper2 = ants.contrib.FlipImage(1,0)
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = flipper2.transform(img2d)
        >>> ants.plot(img2d)
        >>> ants.plot(img2d_r)
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        insuffix = X._libsuffix
        flip_fn = utils.get_lib_fn('flipAntsImage%s' % (insuffix))
        flipped_ptr = flip_fn(X.pointer, self.axis1, self.axis2)
        # NOTE(review): the origin is copied from the input explicitly; presumably
        # the compiled routine's output does not carry it over — confirm.
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=flipped_ptr,
                             origin=X.origin)
634
635
636
class TranslateImage(object):
    """
    Translate an image in physical space. This function calls 
    highly optimized ITK/C++ code.
    """
    def __init__(self, translation, reference=None, interp='linear'):
        """
        Initialize a TranslateImage transform

        Arguments
        ---------
        translation : list, tuple, or numpy.ndarray
            absolute translation along each axis; one value per image dimension

        reference : ANTsImage (optional)
            image which provides the reference physical space in which
            to perform the transform. If None, the image being transformed
            is used as its own reference.

        interp : string
            type of interpolation to use
            options: linear, nearest

        Raises
        ------
        ValueError
            if interp is not one of {linear, nearest}

        Example
        -------
        >>> import ants
        >>> translater = ants.contrib.TranslateImage((10,10), interp='linear')
        """
        if interp not in {'linear', 'nearest'}:
            raise ValueError('interp must be one of {linear, nearest}')

        self.translation = list(translation)
        self.reference = reference
        self.interp = interp

    def transform(self, X, y=None):
        """
        Translate an image by `self.translation`.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            translated image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float', if the number of translation
            values differs from X.dimension, or if the reference image has
            a different dimension than X

        Example
        -------
        >>> import ants
        >>> translater = ants.contrib.TranslateImage((40,0))
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = translater.transform(img2d)
        >>> ants.plot(img2d, img2d_r)
        >>> translater = ants.contrib.TranslateImage((40,0,0))
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = translater.transform(img3d)
        >>> ants.plot(img3d, img3d_r, axis=2)
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        if len(self.translation) != X.dimension:
            raise ValueError('must give a translation value for each image dimension')

        reference = X if self.reference is None else self.reference
        # Fail early with a clear message rather than handing a mismatched
        # reference pointer to the compiled routine.
        if reference.dimension != X.dimension:
            raise ValueError('reference image must have the same dimension as the input image')

        insuffix = X._libsuffix
        # The interpolation mode selects a distinct compiled function (e.g. ..._linear).
        translate_fn = utils.get_lib_fn('translateAntsImage%s_%s' % (insuffix, self.interp))
        translated_ptr = translate_fn(X.pointer, reference.pointer, self.translation)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=translated_ptr)
700
701
702
class ScaleImage(object):
    """
    Scale an image in physical space. This function calls 
    highly optimized ITK/C++ code.
    """
    def __init__(self, scale, reference=None, interp='linear'):
        """
        Initialize a ScaleImage transform

        Arguments
        ---------
        scale : list, tuple, or numpy.ndarray
            relative scaling along each axis; one value per image dimension

        reference : ANTsImage (optional)
            image which provides the reference physical space in which
            to perform the transform. If None, the image being transformed
            is used as its own reference.

        interp : string
            type of interpolation to use
            options: linear, nearest

        Raises
        ------
        ValueError
            if interp is not one of {linear, nearest}

        Example
        -------
        >>> import ants
        >>> scaler = ants.contrib.ScaleImage((1.2,1.2), interp='linear')
        """
        if interp not in {'linear', 'nearest'}:
            raise ValueError('interp must be one of {linear, nearest}')

        self.scale = list(scale)
        self.reference = reference
        self.interp = interp

    def transform(self, X, y=None):
        """
        Scale an image by `self.scale`.

        Arguments
        ---------
        X : ANTsImage
            image to transform; must have pixeltype 'float'

        y : ANTsImage (optional)
            currently unused by this method

        Returns
        -------
        ANTsImage
            scaled image with the same pixeltype, dimension and
            components as X

        Raises
        ------
        ValueError
            if X.pixeltype is not 'float', if the number of scale values
            differs from X.dimension, or if the reference image has a
            different dimension than X

        Example
        -------
        >>> import ants
        >>> scaler = ants.contrib.ScaleImage((1.2,1.2))
        >>> img2d = ants.image_read(ants.get_data('r16'))
        >>> img2d_r = scaler.transform(img2d)
        >>> ants.plot(img2d, img2d_r)
        >>> scaler = ants.contrib.ScaleImage((1.2,1.2,1.2))
        >>> img3d = ants.image_read(ants.get_data('mni'))
        >>> img3d_r = scaler.transform(img3d)
        >>> ants.plot(img3d, img3d_r)
        """
        if X.pixeltype != 'float':
            raise ValueError('image.pixeltype must be float ... use CastIntensity transform or clone to float')

        if len(self.scale) != X.dimension:
            raise ValueError('must give a scale value for each image dimension')

        reference = X if self.reference is None else self.reference
        # Fail early with a clear message rather than handing a mismatched
        # reference pointer to the compiled routine.
        if reference.dimension != X.dimension:
            raise ValueError('reference image must have the same dimension as the input image')

        insuffix = X._libsuffix
        # The interpolation mode selects a distinct compiled function (e.g. ..._linear).
        scale_fn = utils.get_lib_fn('scaleAntsImage%s_%s' % (insuffix, self.interp))
        scaled_ptr = scale_fn(X.pointer, reference.pointer, self.scale)
        return iio.ANTsImage(pixeltype=X.pixeltype, dimension=X.dimension,
                             components=X.components, pointer=scaled_ptr)
766