Switch to unified view

a b/ants/registration/build_template.py
1
__all__ = ["build_template"]
2
3
import numpy as np
4
import os
5
import shutil
6
from tempfile import mktemp
7
8
import ants
9
10
def build_template(
11
    initial_template=None,
12
    image_list=None,
13
    iterations=3,
14
    gradient_step=0.2,
15
    blending_weight=0.75,
16
    weights=None,
17
    useNoRigid=True,
18
    output_dir=None,
19
    **kwargs
20
):
21
    """
22
    Estimate an optimal template from an input image_list
23
24
    ANTsR function: N/A
25
26
    Arguments
27
    ---------
28
    initial_template : ANTsImage
29
        initialization for the template building
30
31
    image_list : ANTsImages
32
        images from which to estimate template
33
34
    iterations : integer
35
        number of template building iterations
36
37
    gradient_step : scalar
38
        for shape update gradient
39
40
    blending_weight : scalar
41
        weight for image blending
42
43
    weights : vector
44
        weight for each input image
45
46
    useNoRigid : boolean
47
        equivalent of -y in the script. Template update
48
        step will not use the rigid component if this is True.
49
50
    output_dir : path
51
        directory name where intermediate transforms are written
52
53
    kwargs : keyword args
54
        extra arguments passed to ants registration
55
56
    Returns
57
    -------
58
    ANTsImage
59
60
    Example
61
    -------
62
    >>> import ants
63
    >>> image = ants.image_read( ants.get_ants_data('r16') )
64
    >>> image2 = ants.image_read( ants.get_ants_data('r27') )
65
    >>> image3 = ants.image_read( ants.get_ants_data('r85') )
66
    >>> timage = ants.build_template( image_list = ( image, image2, image3 ) ).resample_image( (45,45))
67
    >>> timagew = ants.build_template( image_list = ( image, image2, image3 ), weights = (5,1,1) )
68
    """
69
    work_dir = mktemp() if output_dir is None else output_dir
70
71
    def make_outprefix(k: int):
72
        os.makedirs(os.path.join(work_dir, f"img{k:04d}"), exist_ok=True)
73
        return os.path.join(work_dir, f"img{k:04d}", "out")
74
75
    if "type_of_transform" not in kwargs:
76
        type_of_transform = "SyN"
77
    else:
78
        type_of_transform = kwargs.pop("type_of_transform")
79
80
    if weights is None:
81
        weights = np.repeat(1.0 / len(image_list), len(image_list))
82
    weights = [x / sum(weights) for x in weights]
83
    if initial_template is None:
84
        initial_template = image_list[0] * 0
85
        for i in range(len(image_list)):
86
            temp = image_list[i] * weights[i]
87
            temp = ants.resample_image_to_target(temp, initial_template)
88
            initial_template = initial_template + temp
89
90
    xavg = initial_template.clone()
91
    for i in range(iterations):
92
        affinelist = []
93
        for k in range(len(image_list)):
94
            w1 = ants.registration(
95
                xavg, image_list[k], type_of_transform=type_of_transform, outprefix=make_outprefix(k), **kwargs
96
            )
97
            L = len(w1["fwdtransforms"])
98
            # affine is the last one
99
            affinelist.append(w1["fwdtransforms"][L-1])
100
101
            if k == 0:
102
                if L == 2:
103
                    wavg = ants.image_read(w1["fwdtransforms"][0]) * weights[k]
104
                xavgNew = w1["warpedmovout"] * weights[k]
105
            else:
106
                if L == 2:
107
                    wavg = wavg + ants.image_read(w1["fwdtransforms"][0]) * weights[k]
108
                xavgNew = xavgNew + w1["warpedmovout"] * weights[k]
109
110
        if useNoRigid:
111
            avgaffine = ants.average_affine_transform_no_rigid(affinelist)
112
        else:
113
            avgaffine = ants.average_affine_transform(affinelist)
114
        afffn = os.path.join(work_dir, "avgAffine.mat")
115
        ants.write_transform(avgaffine, afffn)
116
117
        if L == 2:
118
            print(wavg.abs().mean())
119
            wscl = (-1.0) * gradient_step
120
            wavg = wavg * wscl
121
            # apply affine to the nonlinear?
122
            # need to save the average
123
            wavgA = ants.apply_transforms(fixed=xavgNew, moving=wavg, imagetype=1, transformlist=afffn, whichtoinvert=[1])
124
            wavgfn = os.path.join(work_dir, "avgWarp.nii.gz")
125
            ants.image_write(wavgA, wavgfn)
126
            xavg = ants.apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[wavgfn, afffn], whichtoinvert=[0, 1])
127
        else:
128
            xavg = ants.apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[afffn], whichtoinvert=[1])
129
            
130
        if blending_weight is not None:
131
            xavg = xavg * blending_weight + ants.iMath(xavg, "Sharpen") * (
132
                1.0 - blending_weight
133
            )
134
135
    if output_dir is None:
136
        shutil.rmtree(work_dir)
137
    return xavg