a b/ants/plotting/plot.py
1
"""
Plotting utilities for ANTsImage objects.
"""

# public API of this module
__all__ = ["plot"]
9
10
import fnmatch
11
import math
12
import os
13
import warnings
14
15
from matplotlib import gridspec
16
import matplotlib.pyplot as plt
17
import matplotlib.patheffects as path_effects
18
import matplotlib.lines as mlines
19
import matplotlib.patches as patches
20
import matplotlib.mlab as mlab
21
import matplotlib.animation as animation
22
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
23
24
import numpy as np
25
import ants
26
from ants.decorators import image_method
27
28
def _mirror_matrix(x):
    """Return the 2D array `x` with its row order reversed (vertical flip)."""
    return x[::-1, :]


def _rotate90_matrix(x):
    """Return the transpose of the 2D array `x`."""
    return x.T


def _rotate270_matrix(x):
    """Return the transpose of the 2D array `x` with its row order reversed."""
    return _mirror_matrix(x.T)


def _reorient_slice(x, axis):
    """
    Orient a 2D slice taken along `axis` for display with ``imshow``.

    Slices along axis 2 are transposed and row-flipped; slices along any
    other axis are only transposed. In both cases a final row flip follows.
    """
    x = _rotate270_matrix(x) if axis == 2 else _rotate90_matrix(x)
    return _mirror_matrix(x)


def _resolve_axis(axis):
    """
    Map a named axis to its integer index; any other value is returned unchanged.

    "x"/"sagittal" -> 0, "y"/"coronal" -> 1, "z"/"axial" -> 2. The historical
    spelling "saggittal" is still accepted for backward compatibility.
    """
    if axis in ("x", "sagittal", "saggittal"):
        return 0
    if axis in ("y", "coronal"):
        return 1
    if axis in ("z", "axial"):
        return 2
    return axis


def _mask_zeros(arr):
    """Return `arr` as a masked array with every zero-valued element masked out."""
    return np.ma.masked_array(arr, mask=(arr == 0))


def _auto_slice_indices(img_arr, nslices, slice_buffer, reverse):
    """
    Pick `nslices` evenly spaced indices along axis 0 of `img_arr`.

    The range is bounded by the first/last slices whose intensity sum exceeds
    0.01, shrunk inwards by `slice_buffer` (an int, a 2-sequence of
    (start, end) buffers, or None for 10% of the in-plane dimensions).
    """
    if not isinstance(slice_buffer, (list, tuple)):
        if slice_buffer is None:
            slice_buffer = (
                int(img_arr.shape[1] * 0.1),
                int(img_arr.shape[2] * 0.1),
            )
        else:
            slice_buffer = (slice_buffer, slice_buffer)

    nonzero = np.where(img_arr.sum(axis=(1, 2)) > 0.01)[0]
    if nonzero.size == 0:
        # e.g. an image whose intensities are all negative: without this
        # fallback `nonzero[0]` raises IndexError; use the whole extent instead
        nonzero = np.arange(img_arr.shape[0])

    min_idx = nonzero[0] + slice_buffer[0]
    max_idx = nonzero[-1] - slice_buffer[1]
    if min_idx > max_idx:
        min_idx, max_idx = max_idx, min_idx
    max_idx = min(max_idx, nonzero.max())
    min_idx = max(min_idx, 0)

    slice_idxs = np.linspace(min_idx, max_idx, nslices).astype("int")
    if reverse:
        slice_idxs = slice_idxs[::-1]
    return slice_idxs


def _user_slice_indices(slices, n):
    """
    Normalise user-supplied `slices` into a sequence of indices along an axis of length `n`.

    A single number becomes a one-element list. If no value exceeds 1 the
    values are treated as fractions of `n`; the result is clamped to the last
    valid index so a fraction of 1.0 selects the final slice instead of
    indexing out of bounds.
    """
    if isinstance(slices, (int, float, np.integer, np.floating)):
        slices = [slices]
    if not any(s > 1 for s in slices):
        slices = [min(int(s * n), n - 1) for s in slices]
    return slices


@image_method
def plot(
    image,
    overlay=None,
    blend=False,
    alpha=1,
    cmap="Greys_r",
    overlay_cmap="turbo",
    overlay_alpha=0.9,
    vminol=None,
    vmaxol=None,
    cbar=False,
    cbar_length=0.8,
    cbar_dx=0.0,
    cbar_vertical=True,
    axis=0,
    nslices=12,
    slices=None,
    ncol=None,
    slice_buffer=None,
    black_bg=True,
    bg_thresh_quant=0.01,
    bg_val_quant=0.99,
    domain_image_map=None,
    crop=False,
    scale=False,
    reverse=False,
    title=None,
    title_fontsize=20,
    title_dx=0.0,
    title_dy=0.0,
    filename=None,
    dpi=500,
    figsize=1.5,
    reorient=True,
    resample=True,
):
    """
    Plot an ANTsImage.

    Use mask_image and/or threshold_image to preprocess images to be
    overlaid and display the overlays in a given range. See the wiki examples.

    By default, 3D images will be reoriented to 'LAI' orientation before
    plotting. So, if axis == 0, the images will be ordered from the
    left side of the brain to the right side of the brain. If axis == 1,
    the images will be ordered from the anterior (front) of the brain to
    the posterior (back) of the brain. And if axis == 2, the images will
    be ordered from the inferior (bottom) of the brain to the superior (top)
    of the brain.

    ANTsR function: `plot.antsImage`

    Arguments
    ---------
    image : ANTsImage or string
        image to plot, or a path that is read with `ants.image_read`

    overlay : ANTsImage or string
        image to overlay on base image (or a path to one). Zero-valued voxels
        of the overlay are masked out (transparent). If the overlay is not in
        the same physical space as `image`, it is resampled onto `image`
        with nearest-neighbor interpolation.

    blend : boolean
        if True and an overlay is given, the overlay is mixed into the base
        image as `image * alpha + overlay * (1 - alpha)` (with `alpha`
        defaulting to 0.5 when left at 1) instead of being drawn on top.

    alpha : float
        transparency of the base image in 2D plots; also the blend weight
        when `blend=True`.

    cmap : string
        colormap to use for base image. See matplotlib.

    overlay_cmap : string
        colormap to use for overlay images, if applicable. See matplotlib.

    overlay_alpha : float
        level of transparency for any overlays. Smaller value means
        the overlay is more transparent. See matplotlib.

    vminol, vmaxol : float
        display range for the overlay colormap. Default to the overlay's
        minimum and maximum intensity.

    cbar : boolean
        if True, draw a colorbar for the last drawn layer.

    cbar_length : float
        length of the colorbar as a fraction of the figure (3D plots only).

    cbar_dx : float
        offset applied to the colorbar position (3D plots only).

    cbar_vertical : boolean
        if True the colorbar is vertical, otherwise horizontal.

    axis : integer or string
        which axis to plot along if image is 3D. Also accepts "x"/"sagittal",
        "y"/"coronal" and "z"/"axial".

    nslices : integer
        number of slices to plot if image is 3D

    slices : number or list or tuple of numbers
        specific slice indices to plot if image is 3D.
        If given, this will override `nslices`.
        This can be absolute array indices (e.g. (80,100,120)), or
        this can be relative array indices (e.g. (0.4,0.5,0.6)). Values are
        treated as relative when none of them exceeds 1.

    ncol : integer
        Number of columns to have on the plot if image is 3D.

    slice_buffer : integer or 2-tuple of integers
        how many slices to buffer when finding the non-zero slices of
        a 3D images. So, if slice_buffer = 10, then the first slice
        in a 3D image will be the first non-zero slice index plus 10 more
        slices.

    black_bg : boolean
        if True, the background of the image(s) will be black.
        if False, the background of the image(s) will be determined by the
            values `bg_thresh_quant` and `bg_val_quant`.

    bg_thresh_quant : float
        if black_bg=False, voxels below the `bg_thresh_quant` quantile of the
        image intensities are treated as background and set to the
        `bg_val_quant` quantile intensity.
        This value should be in [0, 1] - somewhere around 0.01 is recommended.
            - equal to 1 will threshold the entire image
            - equal to 0 will threshold none of the image

    bg_val_quant : float
        if black_bg=False, the quantile of the image intensities used as the
        background intensity (see `bg_thresh_quant`).
        This value should be in [0, 1]
            - equal to 1 is pure white
            - equal to 0 is pure black
            - somewhere in between is gray

    domain_image_map : ANTsImage
        reference image. If supplied, the image (and overlay) are mapped into
        the space of this image before plotting - useful for non-standard
        image orientations.

    crop : boolean
        if true, the image(s) will be cropped to their bounding boxes, resulting
        in a potentially smaller image size.
        if false, the image(s) will not be cropped

    scale : boolean or 2-tuple
        if True, the display range is set to the 5th-95th percentile of the
        image intensities.
        if False, matplotlib's default display range is used.
        if 2-tuple, the display range is set to these quantiles.
        Scaling is always disabled for non-floating-point and RGB images.

    reverse : boolean
        if true, the order in which the slices are plotted will be reversed.
        This is useful if you want to plot from the front of the brain first
        to the back of the brain, or vice-versa

    title : string
        add a title to the plot

    title_fontsize : integer
        font size of the title

    title_dx, title_dy : float
        offsets added to the default title position (x=0.5, y=0.95)

    filename : string
        if given, the resulting image will be saved to this file instead of
        being shown

    dpi : integer
        determines resolution of image if saved to file. Higher values
        result in higher resolution images, but at a cost of having a
        larger file size

    figsize : float
        size (in inches) of each slice panel in 3D plots; the panel width is
        scaled by the slice aspect ratio.

    reorient : bool
        if true, 3D images (and overlays) are reoriented to 'LAI' before slicing.

    resample : bool
        if true, a 3D image whose in-plane spacing ratio exceeds 3 is
        resampled to 1x1x1 spacing before plotting.

    Returns
    -------
    None. The figure is shown with `plt.show()` or written to `filename`.
    If the image is entirely zero, a warning is issued and nothing is plotted.

    Raises
    ------
    ValueError
        if `image`/`overlay` is not an ANTsImage, the overlay has more than one
        component, `scale` is a sequence without exactly two values, or a
        single-component image is neither 2D nor 3D.
    Exception
        if the image has more than one voxel component.

    Example
    -------
    >>> import ants
    >>> import numpy as np
    >>> img = ants.image_read(ants.get_data('r16'))
    >>> segs = img.kmeans_segmentation(k=3)['segmentation']
    >>> ants.plot(img, segs*(segs==1), crop=True)
    >>> ants.plot(img, segs*(segs==1), crop=False)
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> segs = mni.kmeans_segmentation(k=3)['segmentation']
    >>> ants.plot(mni, segs*(segs==1), crop=False)
    """
    axis = _resolve_axis(axis)

    # handle `image` argument
    if isinstance(image, str):
        image = ants.image_read(image)
    if not ants.is_image(image):
        raise ValueError("image argument must be an ANTsImage")

    if np.all(np.equal(image.numpy(), 0.0)):
        warnings.warn("Image must be non-zero. will not plot.")
        return

    # matplotlib emits a spurious NaN warning when drawing masked overlays.
    # Silence warnings only for the duration of this call: catch_warnings
    # restores the caller's filters on exit, even if an error is raised below.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        if (image.pixeltype not in {"float", "double"}) or image.is_rgb:
            scale = False  # turn off scaling if image is discrete

        # handle `overlay` argument
        if overlay is not None:
            if isinstance(overlay, str):
                overlay = ants.image_read(overlay)
            # validate before querying intensities so a bad argument raises the
            # documented ValueError rather than an AttributeError
            if not ants.is_image(overlay):
                raise ValueError("overlay argument must be an ANTsImage")
            if overlay.components > 1:
                raise ValueError("overlay cannot have more than one voxel component")
            if vminol is None:
                vminol = overlay.min()
            if vmaxol is None:
                vmaxol = overlay.max()

            if not ants.image_physical_space_consistency(image, overlay):
                overlay = ants.resample_image_to_target(
                    overlay, image, interp_type="nearestNeighbor"
                )

            if blend:
                # fold the overlay into the base image; nothing is drawn on top
                if alpha == 1:
                    alpha = 0.5
                image = image * alpha + overlay * (1 - alpha)
                overlay = None
                alpha = 1.0

        # handle `domain_image_map` argument
        if domain_image_map is not None:
            tx = ants.new_ants_transform(
                precision="float",
                transform_type="AffineTransform",
                dimension=image.dimension,
            )
            image = ants.apply_ants_transform_to_image(tx, image, domain_image_map)
            if overlay is not None:
                overlay = ants.apply_ants_transform_to_image(
                    tx, overlay, domain_image_map, interpolation="nearestNeighbor"
                )

        ## single-channel images ##
        if image.components == 1:

            # potentially crop image
            if crop:
                plotmask = image.get_mask(cleanup=0)
                if plotmask.max() == 0:
                    plotmask += 1
                image = image.crop_image(plotmask)
                if overlay is not None:
                    overlay = overlay.crop_image(plotmask)

            # potentially find dynamic range; `== True` (not `is True`) keeps
            # accepting truthy values equal to True such as 1
            if scale == True:  # noqa: E712
                vmin, vmax = image.quantile((0.05, 0.95))
            elif isinstance(scale, (list, tuple)):
                if len(scale) != 2:
                    raise ValueError(
                        "scale argument must be boolean or list/tuple with two values"
                    )
                vmin, vmax = image.quantile(scale)
            else:
                vmin = vmax = None

            # Plot 2D image
            if image.dimension == 2:
                img_arr = _rotate90_matrix(image.numpy())

                if not black_bg:
                    img_arr[img_arr < image.quantile(bg_thresh_quant)] = image.quantile(
                        bg_val_quant
                    )

                if overlay is not None:
                    ov_arr = _rotate90_matrix(_mask_zeros(overlay.numpy()))

                fig = plt.figure()
                if title is not None:
                    fig.suptitle(
                        title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy
                    )

                ax = plt.subplot(111)

                # plot main image
                im = ax.imshow(img_arr, cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax)

                if overlay is not None:
                    im = ax.imshow(
                        ov_arr,
                        alpha=overlay_alpha,
                        cmap=overlay_cmap,
                        vmin=vminol,
                        vmax=vmaxol,
                    )

                if cbar:
                    cbar_orient = "vertical" if cbar_vertical else "horizontal"
                    fig.colorbar(im, orientation=cbar_orient)

                plt.axis("off")

            # Plot 3D image
            elif image.dimension == 3:
                # resample image if in-plane spacing is very unbalanced; checking
                # `resample` first avoids the spacing division when it is disabled
                spacing = [s for i, s in enumerate(image.spacing) if i != axis]
                was_resampled = False
                if resample and (max(spacing) / min(spacing)) > 3.0:
                    was_resampled = True
                    new_spacing = (1, 1, 1)
                    image = image.resample_image(new_spacing)
                    if overlay is not None:
                        overlay = overlay.resample_image(new_spacing)

                if reorient:
                    image = image.reorient_image2("LAI")
                # reorder dims so that chosen axis is first
                img_arr = np.rollaxis(image.numpy(), axis)

                if overlay is not None:
                    if reorient:
                        overlay = overlay.reorient_image2("LAI")
                    ov_arr = np.rollaxis(_mask_zeros(overlay.numpy()), axis)

                if slices is None:
                    slice_idxs = _auto_slice_indices(
                        img_arr, nslices, slice_buffer, reverse
                    )
                else:
                    slice_idxs = _user_slice_indices(slices, img_arr.shape[0])
                    nslices = len(slice_idxs)

                if was_resampled:
                    # re-calculate slices to account for the new image shape and
                    # drop duplicates; dict.fromkeys keeps the original order so
                    # `reverse` (and user-given ordering) are honoured
                    ratio = image.shape[axis] / img_arr.shape[0]
                    slice_idxs = np.array(
                        list(dict.fromkeys(int(s * ratio) for s in slice_idxs))
                    )

                # only have one row if nslices <= 6 and user didnt specify ncol
                if ncol is None:
                    ncol = nslices if nslices <= 6 else int(round(math.sqrt(nslices)))

                # calculate grid size
                nrow = math.ceil(nslices / ncol)
                xdim = img_arr.shape[2]
                ydim = img_arr.shape[1]

                dim_ratio = ydim / xdim
                fig = plt.figure(
                    figsize=((ncol + 1) * figsize * dim_ratio, (nrow + 1) * figsize)
                )
                if title is not None:
                    fig.suptitle(
                        title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy
                    )

                gs = gridspec.GridSpec(
                    nrow,
                    ncol,
                    wspace=0.0,
                    hspace=0.0,
                    top=1.0 - 0.5 / (nrow + 1),
                    bottom=0.5 / (nrow + 1),
                    left=0.5 / (ncol + 1),
                    right=1 - 0.5 / (ncol + 1),
                )

                # background replacement values are loop-invariant
                if not black_bg:
                    bg_thresh = image.quantile(bg_thresh_quant)
                    bg_val = image.quantile(bg_val_quant)

                slice_idx_idx = 0
                for i in range(nrow):
                    for j in range(ncol):
                        has_slice = slice_idx_idx < len(slice_idxs)
                        if has_slice:
                            imslice = _reorient_slice(
                                img_arr[slice_idxs[slice_idx_idx]], axis
                            )
                            if not black_bg:
                                # copy so img_arr itself is left untouched
                                imslice = imslice.copy()
                                imslice[imslice < bg_thresh] = bg_val
                        else:
                            # pad unused grid cells with an empty slice
                            imslice = _reorient_slice(np.zeros_like(img_arr[0]), axis)

                        ax = plt.subplot(gs[i, j])
                        im = ax.imshow(imslice, cmap=cmap, vmin=vmin, vmax=vmax)

                        if overlay is not None and has_slice:
                            ovslice = _reorient_slice(
                                ov_arr[slice_idxs[slice_idx_idx]], axis
                            )
                            im = ax.imshow(
                                ovslice,
                                alpha=overlay_alpha,
                                cmap=overlay_cmap,
                                vmin=vminol,
                                vmax=vmaxol,
                            )
                        ax.axis("off")
                        slice_idx_idx += 1

                if cbar:
                    cbar_start = (1 - cbar_length) / 2
                    if cbar_vertical:
                        cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length])
                        cbar_orient = "vertical"
                    else:
                        cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03])
                        cbar_orient = "horizontal"
                    fig.colorbar(im, cax=cax, orientation=cbar_orient)

            else:
                # previously fell through with `fig` undefined (NameError on save)
                raise ValueError(
                    f"only 2D and 3D images can be plotted, got a {image.dimension}D image"
                )

        ## multi-channel images ##
        elif image.has_components:
            raise Exception("Plotting images with components is not currently supported.")

        if filename is not None:
            filename = os.path.expanduser(filename)
            plt.savefig(filename, dpi=dpi, transparent=True, bbox_inches="tight")
            plt.close(fig)
        else:
            plt.show()
485
486