a b/ants/segmentation/joint_label_fusion.py
1
"""
2
Joint Label Fusion algorithm
3
"""
4
5
# Public API of this module: whole-image and region-focused joint label fusion.
__all__ = ["joint_label_fusion", "local_joint_label_fusion"]
6
7
import os
8
import numpy as np
9
import warnings
10
from pathlib import Path
11
from tempfile import NamedTemporaryFile
12
from tempfile import mktemp
13
import glob
14
import re
15
import math
16
17
import ants
18
from ants.internal import get_lib_fn, get_pointer_string, process_arguments
19
20
21
def joint_label_fusion(
22
    target_image,
23
    target_image_mask,
24
    atlas_list,
25
    beta=4,
26
    rad=2,
27
    label_list=None,
28
    rho=0.01,
29
    usecor=False,
30
    r_search=3,
31
    nonnegative=False,
32
    no_zeroes=False,
33
    max_lab_plus_one=False,
34
    output_prefix=None,
35
    verbose=False,
36
):
37
    """
38
    A multiple atlas voting scheme to customize labels for a new subject.
39
    This function will also perform intensity fusion. It almost directly
40
    calls the C++ in the ANTs executable so is much faster than other
41
    variants in ANTsR.
42
43
    One may want to normalize image intensities for each input image before
44
    passing to this function. If no labels are passed, we do intensity fusion.
45
    Note on computation time: the underlying C++ is multithreaded.
46
    You can control the number of threads by setting the environment
47
    variable ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS e.g. to use all or some
48
    of your CPUs. This will improve performance substantially.
49
    For instance, on a macbook pro from 2015, 8 cores improves speed by about 4x.
50
51
    ANTsR function: `jointLabelFusion`
52
53
    Arguments
54
    ---------
55
    target_image : ANTsImage
56
        image to be approximated
57
58
    target_image_mask : ANTsImage
59
        mask with value 1
60
61
    atlas_list : list of ANTsImage types
62
        list containing intensity images
63
64
    beta : scalar
65
        weight sharpness, default to 2
66
67
    rad : scalar
68
        neighborhood radius, default to 2
69
70
    label_list : list of ANTsImage types (optional)
71
        list containing images with segmentation labels
72
73
    rho : scalar
74
        ridge penalty increases robustness to outliers but also makes image converge to average
75
76
    usecor : boolean
77
        employ correlation as local similarity
78
79
    r_search : scalar
80
        radius of search, default is 3
81
82
    nonnegative : boolean
83
        constrain weights to be non-negative
84
85
    no_zeroes : boolean
86
        this will constrain the solution only to voxels that are always non-zero in the label list
87
88
    max_lab_plus_one : boolean
89
        this will add max label plus one to the non-zero parts of each label where the target mask
90
        is greater than one.  NOTE: this will have a side effect of adding to the original label
91
        images that are passed to the program.  It also guarantees that every position in the
92
        labels have some label, rather than none.  Ie it guarantees to explicitly parcellate the
93
        input data.
94
95
    output_prefix: string
96
        file prefix for storing output probabilityimages to disk
97
98
    verbose : boolean
99
        whether to show status updates
100
101
    Returns
102
    -------
103
    dictionary w/ following key/value pairs:
104
        `segmentation` : ANTsImage
105
            segmentation image
106
107
        `intensity` : ANTsImage
108
            intensity image
109
110
        `probabilityimages` : list of ANTsImage types
111
            probability map image for each label
112
113
        `segmentation_numbers` : list of numbers
114
            segmentation label (number, int) for each probability map
115
116
117
    Example
118
    -------
119
    >>> import ants
120
    >>> ref = ants.image_read( ants.get_ants_data('r16'))
121
    >>> ref = ants.resample_image(ref, (50,50),1,0)
122
    >>> ref = ants.iMath(ref,'Normalize')
123
    >>> mi = ants.image_read( ants.get_ants_data('r27'))
124
    >>> mi2 = ants.image_read( ants.get_ants_data('r30'))
125
    >>> mi3 = ants.image_read( ants.get_ants_data('r62'))
126
    >>> mi4 = ants.image_read( ants.get_ants_data('r64'))
127
    >>> mi5 = ants.image_read( ants.get_ants_data('r85'))
128
    >>> refmask = ants.get_mask(ref)
129
    >>> refmask = ants.iMath(refmask,'ME',2) # just to speed things up
130
    >>> ilist = [mi,mi2,mi3,mi4,mi5]
131
    >>> seglist = [None]*len(ilist)
132
    >>> for i in range(len(ilist)):
133
    >>>     ilist[i] = ants.iMath(ilist[i],'Normalize')
134
    >>>     mytx = ants.registration(fixed=ref , moving=ilist[i] ,
135
    >>>         type_of_transform = ('Affine') )
136
    >>>     mywarpedimage = ants.apply_transforms(fixed=ref,moving=ilist[i],
137
    >>>             transformlist=mytx['fwdtransforms'])
138
    >>>     ilist[i] = mywarpedimage
139
    >>>     seg = ants.threshold_image(ilist[i],'Otsu', 3)
140
    >>>     seglist[i] = ( seg ) + ants.threshold_image( seg, 1, 3 ).morphology( operation='dilate', radius=3 )
141
    >>> r = 2
142
    >>> pp = ants.joint_label_fusion(ref, refmask, ilist, r_search=2,
143
    >>>                     label_list=seglist, rad=[r]*ref.dimension )
144
    >>> pp = ants.joint_label_fusion(ref,refmask,ilist, r_search=2, rad=[r]*ref.dimension)
145
    """
146
    segpixtype = "unsigned int"
147
    if (label_list is None) or (np.any([l is None for l in label_list])):
148
        doJif = True
149
    else:
150
        doJif = False
151
152
    if not doJif:
153
        if len(label_list) != len(atlas_list):
154
            raise ValueError("len(label_list) != len(atlas_list)")
155
        if no_zeroes:
156
            for label in label_list:
157
                target_image_mask[label == 0] = 0
158
        inlabs = set()
159
        for label in label_list:
160
            values = np.unique(label[target_image_mask != 0 and label != 0])
161
            inlabs = inlabs.union(values)
162
        inlabs = sorted(inlabs)
163
        maxLab = max(inlabs)
164
        if max_lab_plus_one:
165
            for label in label_list:
166
                label[label == 0] = maxLab + 1
167
        mymask = target_image_mask.clone()
168
    else:
169
        mymask = target_image_mask
170
171
###### security issues with mktemp but could not figure out the right solution
172
###### NamedTemporaryFile creates a file with permissions:
173
###### -rw-------  1 stnava  staff
174
###### whereas mktemp gives
175
###### -rw-r--r--  1 stnava  staff
176
###### the latter is what we want - one solution is to use chmod via os but
177
###### am currently too lazy to change one line of code to two or more everywhere
178
179
#    osegfn = NamedTemporaryFile(prefix="antsr", suffix="myseg.nii.gz",delete=False).name
180
    osegfn = mktemp(prefix="antsr", suffix="myseg.nii.gz")
181
    # segdir = osegfn.replace(os.path.basename(osegfn),'')
182
183
    if os.path.exists(osegfn):
184
        os.remove(osegfn)
185
186
    if output_prefix is None:
187
#        probs = NamedTemporaryFile(prefix="antsr", suffix="prob%02d.nii.gz",delete=False).name
188
        probs = mktemp(prefix="antsr", suffix="prob%02d.nii.gz")
189
        probsbase = os.path.basename(probs)
190
        tdir = probs.replace(probsbase, "")
191
        searchpattern = probsbase.replace("%02d", "*")
192
193
    if output_prefix is not None:
194
        probs = output_prefix + "prob%02d.nii.gz"
195
        probpath = Path(probs).parent
196
        Path(probpath).mkdir(parents=True, exist_ok=True)
197
        probsbase = os.path.basename(probs)
198
        tdir = probs.replace(probsbase, "")
199
        searchpattern = probsbase.replace("%02d", "*")
200
201
    mydim = target_image_mask.dimension
202
    if not doJif:
203
        # not sure if these should be allocated or what their size should be
204
        outimg = label_list[1].clone(segpixtype)
205
        outimgi = target_image * 0
206
207
        outimg_ptr = get_pointer_string(outimg)
208
        outimgi_ptr = get_pointer_string(outimgi)
209
        outs = "[%s,%s,%s]" % (outimg_ptr, outimgi_ptr, probs)
210
    else:
211
        outimgi = target_image * 0
212
        outs = get_pointer_string(outimgi)
213
214
    mymask = mymask.clone(segpixtype)
215
    if (not isinstance(rad, (tuple, list))) or (len(rad) == 1):
216
        myrad = [rad] * mydim
217
    else:
218
        myrad = rad
219
220
    if len(myrad) != mydim:
221
        raise ValueError("path radius dimensionality must equal image dimensionality")
222
223
    myrad = "x".join([str(mr) for mr in myrad])
224
    vnum = 1 if verbose else 0
225
    nnum = 1 if nonnegative else 0
226
    mypc = "MSQ"
227
    if usecor:
228
        mypc = "PC"
229
230
    myargs = {
231
        "d": mydim,
232
        "t": target_image,
233
        "a": rho,
234
        "b": beta,
235
        "c": nnum,
236
        "p": myrad,
237
        "m": mypc,
238
        "s": r_search,
239
        "x": mymask,
240
        "o": outs,
241
        "v": vnum,
242
    }
243
244
    kct = len(myargs.keys())
245
    for k in range(len(atlas_list)):
246
        kct += 1
247
        myargs["g-MULTINAME-%i" % kct] = atlas_list[k]
248
        if not doJif:
249
            kct += 1
250
            castseg = label_list[k].clone(segpixtype)
251
            myargs["l-MULTINAME-%i" % kct] = castseg
252
253
    myprocessedargs = process_arguments(myargs)
254
255
    libfn = get_lib_fn("antsJointFusion")
256
    rval = libfn(myprocessedargs)
257
    if rval != 0:
258
        print("Warning: Non-zero return from antsJointFusion")
259
260
    if doJif:
261
        return outimgi
262
263
    probsout = glob.glob(os.path.join(tdir, "*" + searchpattern))
264
    probsout.sort()
265
    probimgs = []
266
#    print( os.system("ls -l "+probsout[0]) )
267
    for idx in range(len(probsout)):
268
        probimgs.append(ants.image_read(probsout[idx]))
269
270
    #    if len(probsout) != (len(inlabs)) and max_lab_plus_one == False:
271
    #        warnings.warn("Length of output probabilities != length of unique input labels")
272
273
    segmentation_numbers = [0] * len(probsout)
274
    for i in range(len(probsout)):
275
        temp = str.split(probsout[i], "prob")
276
        segnum = temp[len(temp) - 1].split(".nii.gz")[0]
277
        segmentation_numbers[i] = int(segnum)
278
279
    if max_lab_plus_one == False:
280
        segmat = ants.images_to_matrix(probimgs, target_image_mask)
281
        finalsegvec = segmat.argmax(axis=0)
282
        finalsegvec2 = finalsegvec.copy()
283
        # mapfinalsegvec to original labels
284
        for i in range(len(probsout)):
285
            temp = str.split(probsout[i], "prob")
286
            segnum = temp[len(temp) - 1].split(".nii.gz")[0]
287
            finalsegvec2[finalsegvec == i] = segnum
288
        outimg = ants.make_image(target_image_mask, finalsegvec2)
289
290
        return {
291
            "segmentation": outimg,
292
            "intensity": outimgi,
293
            "probabilityimages": probimgs,
294
            "segmentation_numbers": segmentation_numbers,
295
        }
296
297
    if max_lab_plus_one == True:
298
        mymaxlab = max(segmentation_numbers)
299
        matchings_indices = [
300
            i
301
            for i, segmentation_numbers in enumerate(segmentation_numbers)
302
            if segmentation_numbers == mymaxlab
303
        ]
304
        background_prob = probimgs[matchings_indices[0]]
305
        background_probfn = probsout[matchings_indices[0]]
306
        del probimgs[matchings_indices[0]]
307
        del probsout[matchings_indices[0]]
308
        del segmentation_numbers[matchings_indices[0]]
309
310
        segmat = ants.images_to_matrix(probimgs, target_image_mask)
311
312
        finalsegvec = segmat.argmax(axis=0)
313
        finalsegvec2 = finalsegvec.copy()
314
        # mapfinalsegvec to original labels
315
        for i in range(len(probsout)):
316
            temp = str.split(probsout[i], "prob")
317
            segnum = temp[len(temp) - 1].split(".nii.gz")[0]
318
            finalsegvec2[finalsegvec == i] = segnum
319
320
        outimg = ants.make_image(target_image_mask, finalsegvec2)
321
322
        # next decide what is "background" based on the sum of the first k labels vs the prob of the last one
323
        firstK = probimgs[0] * 0
324
        for i in range(len(probsout)):
325
            firstK = firstK + probimgs[i]
326
327
        segmat = ants.images_to_matrix([background_prob, firstK], target_image_mask)
328
        bkgsegvec = segmat.argmax(axis=0)
329
        outimg = outimg * ants.make_image(target_image_mask, bkgsegvec)
330
331
        return {
332
            "segmentation": outimg * ants.make_image(target_image_mask, bkgsegvec),
333
            "segmentation_raw": outimg,
334
            "intensity": outimgi,
335
            "probabilityimages": probimgs,
336
            "segmentation_numbers": segmentation_numbers,
337
            "background_prob": background_prob,
338
        }
339
340
341
def local_joint_label_fusion(
    target_image,
    which_labels,
    target_mask,
    initial_label,
    atlas_list,
    label_list,
    submask_dilation=10,
    type_of_transform="SyN",
    aff_metric="meansquares",
    syn_metric="mattes",
    syn_sampling=32,
    reg_iterations=(40, 20, 0),
    aff_iterations=(500, 50, 0),
    grad_step=0.2,
    flow_sigma=3,
    total_sigma=0,
    beta=4,
    rad=2,
    rho=0.1,
    usecor=False,
    r_search=3,
    nonnegative=False,
    no_zeroes=False,
    max_lab_plus_one=False,
    local_mask_transform="Similarity",
    output_prefix=None,
    verbose=False,
):
    """
    A local version of joint label fusion that focuses on a subset of labels.
    This is primarily different from standard JLF because it performs
    registration on the label subset and focuses JLF on those labels alone.

    ANTsR function: `localJointLabelFusion`

    Arguments
    ---------
    target_image : ANTsImage
        image to be labeled

    which_labels : numeric vector
        label number(s) that exist(s) in both the template and library

    target_mask : ANTsImage or None
        a mask for the target image; if not None it restricts the dilated
        region of interest, which is then passed to joint fusion as its mask

    initial_label : ANTsImage
        initial label set, may be same labels as library or binary.
        typically labels would be produced by a single deformable registration
        or by manual labeling. If none of `which_labels` is present, all
        non-zero voxels of `initial_label` define the region.

    atlas_list : list of ANTsImage types
        list containing intensity images

    label_list : list of ANTsImage types
        list containing images with segmentation labels, one per atlas

    submask_dilation : integer
        amount to dilate initial mask to define region on which
        we perform focused registration

    type_of_transform : string
        A linear or non-linear registration type for the local image
        registration of each atlas (which uses `syn_metric`).

    aff_metric : string
        the metric (GC, mattes, meansquares) for the initial alignment of each
        library label region to the target region (`local_mask_transform`)

    syn_metric : string
        the metric for the syn part (CC, mattes, meansquares, demons)

    syn_sampling : scalar
        the nbins or radius parameter for the syn metric

    reg_iterations : list/tuple of integers
        vector of iterations for syn. we will set the smoothing and
        multi-resolution parameters based on the length of this vector.

    aff_iterations : list/tuple of integers
        vector of iterations for the initial label-region alignment.

    grad_step : scalar
        gradient step size (not for all tx)

    flow_sigma : scalar
        smoothing for update field

    total_sigma : scalar
        smoothing for total field

    beta : scalar
        weight sharpness, default to 4

    rad : scalar
        neighborhood radius, default to 2

    rho : scalar
        ridge penalty increases robustness to outliers but also makes image
        converge to average; default to 0.1

    usecor : boolean
        employ correlation as local similarity

    r_search : scalar
        radius of search, default is 3

    nonnegative : boolean
        constrain weights to be non-negative

    no_zeroes : boolean
        this will constrain the solution only to voxels that are always
        non-zero in the label list

    max_lab_plus_one : boolean
        passed to `joint_label_fusion`: zero voxels of each (locally mapped)
        label image are set to max label plus one and treated as background.

    local_mask_transform: string
        the type of transform for the local mask alignment - usually
        translation, rigid, similarity or affine.

    output_prefix: string
        file prefix for storing output probability images to disk

    verbose : boolean
        whether to show status updates

    Returns
    -------
    dictionary w/ following key/value pairs:
        `ljlf` : dict
            the result of `joint_label_fusion` on the cropped region

        `croppedImage` : ANTsImage
            target image cropped to the dilated region of interest

        `croppedmappedImages` : list of ANTsImage types
            atlas intensity images mapped into `croppedImage` space

        `croppedmappedSegs` : list of ANTsImage types
            atlas label images mapped into `croppedImage` space

    Raises
    ------
    ValueError
        if `label_list` and `atlas_list` differ in length
    """
    # Fail fast: every atlas needs a label image, and the registrations below
    # are expensive.
    if len(label_list) != len(atlas_list):
        raise ValueError("len(label_list) != len(atlas_list)")

    myregion = ants.mask_image(initial_label, initial_label, which_labels)
    if myregion.max() == 0:
        # which_labels absent from initial_label (e.g. a binary initial label).
        myregion = ants.threshold_image(initial_label, 1, math.inf)

    myregionb = ants.threshold_image(myregion, 1, math.inf)
    myregionAroundRegion = ants.iMath(myregionb, "MD", submask_dilation)
    if target_mask is not None:
        myregionAroundRegion = myregionAroundRegion * target_mask
    croppedImage = ants.crop_image(target_image, myregionAroundRegion)
    croppedMask = ants.crop_image(myregionAroundRegion, myregionAroundRegion)
    mycroppedregion = ants.crop_image(myregion, myregionAroundRegion)

    croppedmappedImages = []
    croppedmappedSegs = []
    if verbose:
        print("Begin registrations:")
    for k, (atlas, labels) in enumerate(zip(atlas_list, label_list)):
        if verbose:
            print(str(k) + "...")
            print("local-seg-tx: " + local_mask_transform)
        # Align the library's label region to the target region first; the
        # result initializes the local image registration.
        libregion = ants.mask_image(labels, labels, which_labels)
        initMap = ants.registration(
            mycroppedregion,
            libregion,
            type_of_transform=local_mask_transform,
            aff_metric=aff_metric,
            aff_iterations=aff_iterations,
            verbose=False,
        )["fwdtransforms"]
        if verbose:
            print("local-img-tx: " + type_of_transform)
        localReg = ants.registration(
            croppedImage,
            atlas,
            reg_iterations=reg_iterations,
            flow_sigma=flow_sigma,
            total_sigma=total_sigma,
            grad_step=grad_step,
            type_of_transform=type_of_transform,
            syn_metric=syn_metric,
            syn_sampling=syn_sampling,
            initial_transform=initMap[0],
            verbose=False,
        )
        croppedmappedImages.append(
            ants.apply_transforms(croppedImage, atlas, localReg["fwdtransforms"])
        )
        croppedmappedSegs.append(
            ants.apply_transforms(
                croppedImage,
                labels,
                localReg["fwdtransforms"],
                interpolator="nearestNeighbor",
            )
        )

    ljlf = joint_label_fusion(
        croppedImage,
        croppedMask,
        atlas_list=croppedmappedImages,
        label_list=croppedmappedSegs,
        beta=beta,
        rad=rad,
        rho=rho,
        usecor=usecor,
        r_search=r_search,
        nonnegative=nonnegative,
        no_zeroes=no_zeroes,
        max_lab_plus_one=max_lab_plus_one,
        output_prefix=output_prefix,
        verbose=verbose,
    )

    return {
        "ljlf": ljlf,
        "croppedImage": croppedImage,
        "croppedmappedImages": croppedmappedImages,
        "croppedmappedSegs": croppedmappedSegs,
    }