Switch to unified view

a b/ants/plotting/plot_ortho.py
1
"""
2
Functions for plotting ants images
3
"""
4
5
6
__all__ = [
7
    "plot_ortho"
8
]
9
10
import fnmatch
11
import math
12
import os
13
import warnings
14
15
from matplotlib import gridspec
16
import matplotlib.pyplot as plt
17
import matplotlib.patheffects as path_effects
18
import matplotlib.lines as mlines
19
import matplotlib.patches as patches
20
import matplotlib.mlab as mlab
21
import matplotlib.animation as animation
22
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
23
24
25
import numpy as np
26
import ants
27
from ants.decorators import image_method
28
29
@image_method
30
def plot_ortho(
31
    image,
32
    overlay=None,
33
    reorient=True,
34
    blend=False,
35
    # xyz arguments
36
    xyz=None,
37
    xyz_lines=True,
38
    xyz_color="red",
39
    xyz_alpha=0.6,
40
    xyz_linewidth=2,
41
    xyz_pad=5,
42
    orient_labels=True,
43
    # base image arguments
44
    alpha=1,
45
    cmap="Greys_r",
46
    # overlay arguments
47
    overlay_cmap="jet",
48
    overlay_alpha=0.9,
49
    cbar=False,
50
    cbar_length=0.8,
51
    cbar_dx=0.0,
52
    cbar_vertical=True,
53
    # background arguments
54
    black_bg=True,
55
    bg_thresh_quant=0.01,
56
    bg_val_quant=0.99,
57
    # scale/crop/domain arguments
58
    crop=False,
59
    scale=False,
60
    domain_image_map=None,
61
    # title arguments
62
    title=None,
63
    titlefontsize=24,
64
    title_dx=0,
65
    title_dy=0,
66
    # 4th panel text arguemnts
67
    text=None,
68
    textfontsize=24,
69
    textfontcolor="white",
70
    text_dx=0,
71
    text_dy=0,
72
    # save & size arguments
73
    filename=None,
74
    dpi=500,
75
    figsize=1.0,
76
    flat=False,
77
    transparent=True,
78
    resample=False,
79
    allow_xyz_change=True,
80
):
81
    """
82
    Plot an orthographic view of a 3D image
83
84
    Use mask_image and/or threshold_image to preprocess images to be be
85
    overlaid and display the overlays in a given range. See the wiki examples.
86
87
    ANTsR function: N/A
88
89
    Arguments
90
    ---------
91
    image : ANTsImage
92
        image to plot
93
94
    overlay : ANTsImage
95
        image to overlay on base image
96
97
    xyz : list or tuple of 3 integers
98
        selects index location on which to center display
99
        if given, solid lines will be drawn to converge at this coordinate.
100
        This is useful for pinpointing a specific location in the image.
101
102
    flat : boolean
103
        if true, the ortho image will be plot in one row
104
        if false, the ortho image will be a 2x2 grid with the bottom
105
            left corner blank
106
107
    cmap : string
108
        colormap to use for base image. See matplotlib.
109
110
    overlay_cmap : string
111
        colormap to use for overlay images, if applicable. See matplotlib.
112
113
    overlay_alpha : float
114
        level of transparency for any overlays. Smaller value means
115
        the overlay is more transparent. See matplotlib.
116
117
    cbar: boolean
118
        if true, a colorbar will be added to the plot
119
120
    cbar_length: float
121
        length of the colorbar relative to the image
122
123
    cbar_dx: float
124
        horizontal shift of the colorbar relative to the image
125
126
    cbar_vertical: boolean
127
        if true, the colorbar will be vertical, if false, it will be
128
        horizontal underneath the image
129
130
    axis : integer
131
        which axis to plot along if image is 3D
132
133
    black_bg : boolean
134
        if True, the background of the image(s) will be black.
135
        if False, the background of the image(s) will be determined by the
136
            values `bg_thresh_quant` and `bg_val_quant`.
137
138
    bg_thresh_quant : float
139
        if white_bg=True, the background will be determined by thresholding
140
        the image at the `bg_thresh` quantile value and setting the background
141
        intensity to the `bg_val` quantile value.
142
        This value should be in [0, 1] - somewhere around 0.01 is recommended.
143
            - equal to 1 will threshold the entire image
144
            - equal to 0 will threshold none of the image
145
146
    bg_val_quant : float
147
        if white_bg=True, the background will be determined by thresholding
148
        the image at the `bg_thresh` quantile value and setting the background
149
        intensity to the `bg_val` quantile value.
150
        This value should be in [0, 1]
151
            - equal to 1 is pure white
152
            - equal to 0 is pure black
153
            - somewhere in between is gray
154
155
    domain_image_map : ANTsImage
156
        this input ANTsImage or list of ANTsImage types contains a reference image
157
        `domain_image` and optional reference mapping named `domainMap`.
158
        If supplied, the image(s) to be plotted will be mapped to the domain
159
        image space before plotting - useful for non-standard image orientations.
160
161
    crop : boolean
162
        if true, the image(s) will be cropped to their bounding boxes, resulting
163
        in a potentially smaller image size.
164
        if false, the image(s) will not be cropped
165
166
    scale : boolean or 2-tuple
167
        if true, nothing will happen to intensities of image(s) and overlay(s)
168
        if false, dynamic range will be maximized when visualizing overlays
169
        if 2-tuple, the image will be dynamically scaled between these quantiles
170
171
    title : string
172
        add a title to the plot
173
174
    filename : string
175
        if given, the resulting image will be saved to this file
176
177
    dpi : integer
178
        determines resolution of image if saved to file. Higher values
179
        result in higher resolution images, but at a cost of having a
180
        larger file size
181
182
    resample : resample image in case of unbalanced spacing
183
184
    allow_xyz_change : boolean will attempt to adjust xyz after padding
185
186
    Example
187
    -------
188
    >>> import ants
189
    >>> mni = ants.image_read(ants.get_data('mni'))
190
    >>> ants.plot_ortho(mni, xyz=(100,100,100))
191
    >>> mni2 = mni.threshold_image(7000, mni.max())
192
    >>> ants.plot_ortho(mni, overlay=mni2)
193
    >>> ants.plot_ortho(mni, overlay=mni2, flat=True)
194
    >>> ants.plot_ortho(mni, overlay=mni2, xyz=(110,110,110), xyz_lines=False,
195
                        text='Lines Turned Off', textfontsize=22)
196
    >>> ants.plot_ortho(mni, mni2, xyz=(120,100,100),
197
                        text=' Example \nOrtho Text', textfontsize=26,
198
                        title='Example Ortho Title', titlefontsize=26)
199
    """
200
201
    def mirror_matrix(x):
202
        return x[::-1, :]
203
204
    def rotate270_matrix(x):
205
        return mirror_matrix(x.T)
206
207
    def reorient_slice(x, axis):
208
        return rotate270_matrix(x)
209
210
    # need this hack because of a weird NaN warning from matplotlib with overlays
211
    warnings.simplefilter("ignore")
212
213
    # handle `image` argument
214
    if isinstance(image, str):
215
        image = ants.image_read(image)
216
    if not ants.is_image(image):
217
        raise ValueError("image argument must be an ANTsImage")
218
    if image.dimension != 3:
219
        raise ValueError("Input image must have 3 dimensions!")
220
221
    # handle `overlay` argument
222
    if overlay is not None:
223
        if isinstance(overlay, str):
224
            overlay = ants.image_read(overlay)
225
        vminol = overlay.min()
226
        vmaxol = overlay.max()
227
        if not ants.is_image(overlay):
228
            raise ValueError("overlay argument must be an ANTsImage")
229
        if overlay.components > 1:
230
            raise ValueError("overlay cannot have more than one voxel component")
231
        if overlay.dimension != 3:
232
            raise ValueError("Overlay image must have 3 dimensions!")
233
234
        if not ants.image_physical_space_consistency(image, overlay):
235
            overlay = ants.resample_image_to_target(overlay, image, interp_type="linear")
236
237
    if blend:
238
        if alpha == 1:
239
            alpha = 0.5
240
        image = image * alpha + overlay * (1 - alpha)
241
        overlay = None
242
        alpha = 1.0
243
244
    if image.pixeltype not in {"float", "double"}:
245
        scale = False  # turn off scaling if image is discrete
246
247
    # reorient images
248
    if reorient != False:
249
        if reorient == True:
250
            reorient = "RPI"
251
        image = image.reorient_image2("RPI")
252
        if overlay is not None:
253
            overlay = overlay.reorient_image2("RPI")
254
255
    # handle `slices` argument
256
    if xyz is None:
257
        xyz = [int(s / 2) for s in image.shape]
258
    for i in range(3):
259
        if xyz[i] is None:
260
            xyz[i] = int(image.shape[i] / 2)
261
262
    # resample image if spacing is very unbalanced
263
    spacing = [s for i, s in enumerate(image.spacing)]
264
    if (max(spacing) / min(spacing)) > 3.0 and resample:
265
        new_spacing = (1, 1, 1)
266
        image = image.resample_image(tuple(new_spacing))
267
        if overlay is not None:
268
            overlay = overlay.resample_image(tuple(new_spacing))
269
        xyz = [
270
            int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing)
271
        ]
272
273
274
    # potentially crop image
275
    if crop:
276
        plotmask = image.get_mask(cleanup=0)
277
        if plotmask.max() == 0:
278
            plotmask += 1
279
        image = image.crop_image(plotmask)
280
        if overlay is not None:
281
            overlay = overlay.crop_image(plotmask)
282
283
    # pad images
284
    if True:
285
        image, lowpad, uppad = image.pad_image(return_padvals=True)
286
        if allow_xyz_change:
287
            xyz = [v + l for v, l in zip(xyz, lowpad)]
288
        if overlay is not None:
289
            overlay = overlay.pad_image()
290
291
292
    # handle `domain_image_map` argument
293
    if domain_image_map is not None:
294
        if ants.is_image(domain_image_map):
295
            tx = ants.new_ants_transform(
296
                precision="float",
297
                transform_type="AffineTransform",
298
                dimension=image.dimension,
299
            )
300
            image = ants.apply_ants_transform_to_image(tx, image, domain_image_map)
301
            if overlay is not None:
302
                overlay = ants.apply_ants_transform_to_image(
303
                    tx, overlay, domain_image_map, interpolation="linear"
304
                )
305
        else:
306
            raise Exception('The domain_image_map must be an image.')
307
308
    ## single-channel images ##
309
    if image.components == 1:
310
311
        # potentially find dynamic range
312
        if scale == True:
313
            vmin, vmax = image.quantile((0.05, 0.95))
314
        elif isinstance(scale, (list, tuple)):
315
            if len(scale) != 2:
316
                raise ValueError(
317
                    "scale argument must be boolean or list/tuple with two values"
318
                )
319
            vmin, vmax = image.quantile(scale)
320
        else:
321
            vmin = None
322
            vmax = None
323
324
        if not flat:
325
            nrow = 2
326
            ncol = 2
327
        else:
328
            nrow = 1
329
            ncol = 3
330
331
        fig = plt.figure(figsize=(9 * figsize, 9 * figsize))
332
        if title is not None:
333
            basey = 0.88 if not flat else 0.66
334
            basex = 0.5
335
            fig.suptitle(
336
                title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy
337
            )
338
339
        gs = gridspec.GridSpec(
340
            nrow,
341
            ncol,
342
            wspace=0.0,
343
            hspace=0.0,
344
            top=1.0 - 0.5 / (nrow + 1),
345
            bottom=0.5 / (nrow + 1),
346
            left=0.5 / (ncol + 1),
347
            right=1 - 0.5 / (ncol + 1),
348
        )
349
350
        # pad image to have isotropic array dimensions
351
        imageReturn = image.clone()
352
        image = image.numpy()
353
        overlayReturn = None
354
        if overlay is not None:
355
            overlayReturn = overlay.clone()
356
            overlay = overlay.numpy()
357
            if overlay.dtype not in ["uint8", "uint32"]:
358
                overlay = np.ma.masked_where( np.abs(overlay) <= 1e-16, overlay)
359
#                overlay[np.abs(overlay) == 0] = np.nan
360
361
        yz_slice = reorient_slice(image[xyz[0], :, :], 0)
362
        ax = plt.subplot(gs[0, 0])
363
        ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax)
364
        if overlay is not None:
365
            yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0)
366
            ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol )
367
        if xyz_lines:
368
            # add lines
369
            l = mlines.Line2D(
370
                [xyz[1], xyz[1]],
371
                [xyz_pad, yz_slice.shape[0] - xyz_pad],
372
                color=xyz_color,
373
                alpha=xyz_alpha,
374
                linewidth=xyz_linewidth,
375
            )
376
            ax.add_line(l)
377
            l = mlines.Line2D(
378
                [xyz_pad, yz_slice.shape[1] - xyz_pad],
379
                [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]],
380
                color=xyz_color,
381
                alpha=xyz_alpha,
382
                linewidth=xyz_linewidth,
383
            )
384
            ax.add_line(l)
385
        if orient_labels:
386
            ax.text(
387
                0.5,
388
                0.98,
389
                "S",
390
                horizontalalignment="center",
391
                verticalalignment="top",
392
                fontsize=20 * figsize,
393
                color=textfontcolor,
394
                transform=ax.transAxes,
395
            )
396
            ax.text(
397
                0.5,
398
                0.02,
399
                "I",
400
                horizontalalignment="center",
401
                verticalalignment="bottom",
402
                fontsize=20 * figsize,
403
                color=textfontcolor,
404
                transform=ax.transAxes,
405
            )
406
            ax.text(
407
                0.98,
408
                0.5,
409
                "A",
410
                horizontalalignment="right",
411
                verticalalignment="center",
412
                fontsize=20 * figsize,
413
                color=textfontcolor,
414
                transform=ax.transAxes,
415
            )
416
            ax.text(
417
                0.02,
418
                0.5,
419
                "P",
420
                horizontalalignment="left",
421
                verticalalignment="center",
422
                fontsize=20 * figsize,
423
                color=textfontcolor,
424
                transform=ax.transAxes,
425
            )
426
        ax.axis("off")
427
428
        xz_slice = reorient_slice(image[:, xyz[1], :], 1)
429
        ax = plt.subplot(gs[0, 1])
430
        ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax)
431
        if overlay is not None:
432
            xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1)
433
            ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol )
434
435
        if xyz_lines:
436
            # add lines
437
            l = mlines.Line2D(
438
                [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]],
439
                [xyz_pad, xz_slice.shape[0] - xyz_pad],
440
                color=xyz_color,
441
                alpha=xyz_alpha,
442
                linewidth=xyz_linewidth,
443
            )
444
            ax.add_line(l)
445
            l = mlines.Line2D(
446
                [xyz_pad, xz_slice.shape[1] - xyz_pad],
447
                [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]],
448
                color=xyz_color,
449
                alpha=xyz_alpha,
450
                linewidth=xyz_linewidth,
451
            )
452
            ax.add_line(l)
453
        if orient_labels:
454
            ax.text(
455
                0.5,
456
                0.98,
457
                "S",
458
                horizontalalignment="center",
459
                verticalalignment="top",
460
                fontsize=20 * figsize,
461
                color=textfontcolor,
462
                transform=ax.transAxes,
463
            )
464
            ax.text(
465
                0.5,
466
                0.02,
467
                "I",
468
                horizontalalignment="center",
469
                verticalalignment="bottom",
470
                fontsize=20 * figsize,
471
                color=textfontcolor,
472
                transform=ax.transAxes,
473
            )
474
            ax.text(
475
                0.98,
476
                0.5,
477
                "L",
478
                horizontalalignment="right",
479
                verticalalignment="center",
480
                fontsize=20 * figsize,
481
                color=textfontcolor,
482
                transform=ax.transAxes,
483
            )
484
            ax.text(
485
                0.02,
486
                0.5,
487
                "R",
488
                horizontalalignment="left",
489
                verticalalignment="center",
490
                fontsize=20 * figsize,
491
                color=textfontcolor,
492
                transform=ax.transAxes,
493
            )
494
        ax.axis("off")
495
496
        xy_slice = reorient_slice(image[:, :, xyz[2]], 2)
497
        if not flat:
498
            ax = plt.subplot(gs[1, 1])
499
        else:
500
            ax = plt.subplot(gs[0, 2])
501
        im = ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax)
502
        if overlay is not None:
503
            xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2)
504
            im = ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol)
505
506
        if xyz_lines:
507
            # add lines
508
            l = mlines.Line2D(
509
                [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]],
510
                [xyz_pad, xy_slice.shape[0] - xyz_pad],
511
                color=xyz_color,
512
                alpha=xyz_alpha,
513
                linewidth=xyz_linewidth,
514
            )
515
            ax.add_line(l)
516
            l = mlines.Line2D(
517
                [xyz_pad, xy_slice.shape[1] - xyz_pad],
518
                [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]],
519
                color=xyz_color,
520
                alpha=xyz_alpha,
521
                linewidth=xyz_linewidth,
522
            )
523
            ax.add_line(l)
524
        if orient_labels:
525
            ax.text(
526
                0.5,
527
                0.98,
528
                "A",
529
                horizontalalignment="center",
530
                verticalalignment="top",
531
                fontsize=20 * figsize,
532
                color=textfontcolor,
533
                transform=ax.transAxes,
534
            )
535
            ax.text(
536
                0.5,
537
                0.02,
538
                "P",
539
                horizontalalignment="center",
540
                verticalalignment="bottom",
541
                fontsize=20 * figsize,
542
                color=textfontcolor,
543
                transform=ax.transAxes,
544
            )
545
            ax.text(
546
                0.98,
547
                0.5,
548
                "L",
549
                horizontalalignment="right",
550
                verticalalignment="center",
551
                fontsize=20 * figsize,
552
                color=textfontcolor,
553
                transform=ax.transAxes,
554
            )
555
            ax.text(
556
                0.02,
557
                0.5,
558
                "R",
559
                horizontalalignment="left",
560
                verticalalignment="center",
561
                fontsize=20 * figsize,
562
                color=textfontcolor,
563
                transform=ax.transAxes,
564
            )
565
        ax.axis("off")
566
567
        if not flat:
568
            # empty corner
569
            ax = plt.subplot(gs[1, 0])
570
            if text is not None:
571
                # add text
572
                left, width = 0.25, 0.5
573
                bottom, height = 0.25, 0.5
574
                right = left + width
575
                top = bottom + height
576
                ax.text(
577
                    0.5 * (left + right) + text_dx,
578
                    0.5 * (bottom + top) + text_dy,
579
                    text,
580
                    horizontalalignment="center",
581
                    verticalalignment="center",
582
                    fontsize=textfontsize,
583
                    color=textfontcolor,
584
                    transform=ax.transAxes,
585
                )
586
            # ax.text(0.5, 0.5)
587
            ax.imshow(np.zeros(image.shape[:-1]), cmap="Greys_r")
588
            ax.axis("off")
589
590
        if cbar:
591
            cbar_start = (1 - cbar_length) / 2
592
            if cbar_vertical:
593
                cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length])
594
                cbar_orient = "vertical"
595
            else:
596
                cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03])
597
                cbar_orient = "horizontal"
598
            fig.colorbar(im, cax=cax, orientation=cbar_orient)
599
600
    ## multi-channel images ##
601
    elif image.components > 1:
602
        raise ValueError("Multi-channel images not currently supported!")
603
604
    if filename is not None:
605
        plt.savefig(filename, dpi=dpi, transparent=transparent)
606
        plt.close(fig)
607
    else:
608
        plt.show()
609
610
    # turn warnings back to default
611
    warnings.simplefilter("default")
612