"""
Functions for plotting ants images
"""
__all__ = [
"plot_ortho"
]
import fnmatch
import math
import os
import warnings
from matplotlib import gridspec
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
import matplotlib.patches as patches
import matplotlib.mlab as mlab
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import ants
from ants.decorators import image_method
def _reorient_slice(slice_2d):
    """Return `slice_2d` transposed and then flipped along its first axis.

    This is the display transform applied to every 2D slice before it is
    handed to `imshow`.
    """
    return slice_2d.T[::-1, :]


def _add_crosshair(ax, x_pos, y_pos, slice_shape, pad, color, alpha, linewidth):
    """Draw a vertical line at `x_pos` and a horizontal line at `y_pos` on `ax`.

    Each line stops `pad` data units short of the extent given by
    `slice_shape` (the shape of the displayed 2D array).
    """
    style = dict(color=color, alpha=alpha, linewidth=linewidth)
    ax.add_line(
        mlines.Line2D([x_pos, x_pos], [pad, slice_shape[0] - pad], **style)
    )
    ax.add_line(
        mlines.Line2D([pad, slice_shape[1] - pad], [y_pos, y_pos], **style)
    )


def _add_orient_labels(ax, top, bottom, right, left, fontsize, color):
    """Write one direction letter at the middle of each edge of `ax`."""
    placements = (
        (0.5, 0.98, top, "center", "top"),
        (0.5, 0.02, bottom, "center", "bottom"),
        (0.98, 0.5, right, "right", "center"),
        (0.02, 0.5, left, "left", "center"),
    )
    for x, y, letter, halign, valign in placements:
        ax.text(
            x,
            y,
            letter,
            horizontalalignment=halign,
            verticalalignment=valign,
            fontsize=fontsize,
            color=color,
            transform=ax.transAxes,
        )


@image_method
def plot_ortho(
    image,
    overlay=None,
    reorient=True,
    blend=False,
    # xyz arguments
    xyz=None,
    xyz_lines=True,
    xyz_color="red",
    xyz_alpha=0.6,
    xyz_linewidth=2,
    xyz_pad=5,
    orient_labels=True,
    # base image arguments
    alpha=1,
    cmap="Greys_r",
    # overlay arguments
    overlay_cmap="jet",
    overlay_alpha=0.9,
    cbar=False,
    cbar_length=0.8,
    cbar_dx=0.0,
    cbar_vertical=True,
    # background arguments
    black_bg=True,
    bg_thresh_quant=0.01,
    bg_val_quant=0.99,
    # scale/crop/domain arguments
    crop=False,
    scale=False,
    domain_image_map=None,
    # title arguments
    title=None,
    titlefontsize=24,
    title_dx=0,
    title_dy=0,
    # 4th panel text arguments
    text=None,
    textfontsize=24,
    textfontcolor="white",
    text_dx=0,
    text_dy=0,
    # save & size arguments
    filename=None,
    dpi=500,
    figsize=1.0,
    flat=False,
    transparent=True,
    resample=False,
    allow_xyz_change=True,
):
    """
    Plot an orthographic (three-plane) view of a 3D image.

    Use mask_image and/or threshold_image to preprocess images to be
    overlaid and display the overlays in a given range. See the wiki examples.

    ANTsR function: N/A

    Arguments
    ---------
    image : ANTsImage or string
        3D image to plot. A string is read with `ants.image_read`.

    overlay : ANTsImage or string
        single-component 3D image to draw on top of `image`. A string is read
        with `ants.image_read`. If it does not share the physical space of
        `image` it is resampled onto `image` with linear interpolation. The
        overlay color range is the min/max of the overlay as supplied.

    reorient : boolean or string
        unless equal to False, `image` and `overlay` are reoriented to "RPI"
        before plotting. The orientation letters drawn on the panels are
        fixed and correspond to that orientation.

    blend : boolean
        if true, `image * alpha + overlay * (1 - alpha)` is plotted as a
        single image and no separate overlay layer is drawn. An `alpha` of 1
        is replaced by 0.5. Requires `overlay`.

    xyz : list or tuple of 3 integers
        selects index location on which to center display. Entries that are
        None (or `xyz=None`) default to the center index of that axis.
        If given, solid lines will be drawn to converge at this coordinate.
        This is useful for pinpointing a specific location in the image.

    xyz_lines : boolean
        if true, crosshair lines are drawn on every panel

    xyz_color, xyz_alpha, xyz_linewidth
        color, transparency and width of the crosshair lines. See matplotlib.

    xyz_pad : number
        gap, in slice pixels, left between the ends of the crosshair lines
        and the panel border

    orient_labels : boolean
        if true, orientation letters (S/I/A/P/L/R) are written on the panels
        using `textfontcolor`

    alpha : float
        weight of `image` when `blend=True`; otherwise unused

    cmap : string
        colormap to use for base image. See matplotlib.

    overlay_cmap : string
        colormap to use for overlay images, if applicable. See matplotlib.

    overlay_alpha : float
        level of transparency for any overlays. Smaller value means
        the overlay is more transparent. See matplotlib.

    cbar : boolean
        if true, a colorbar will be added to the plot. It describes the
        top-most layer of the axial panel (the overlay when one is drawn).

    cbar_length : float
        length of the colorbar as a fraction of the figure

    cbar_dx : float
        offset of the colorbar in figure coordinates: added to its left edge
        when vertical, to its bottom edge when horizontal

    cbar_vertical : boolean
        if true, the colorbar will be vertical, if false, it will be
        horizontal underneath the image

    black_bg, bg_thresh_quant, bg_val_quant
        accepted for interface compatibility; not used by this function

    crop : boolean
        if true, the image(s) will be cropped to the bounding box of
        `image.get_mask(cleanup=0)`, resulting in a potentially smaller
        image size. If false, the image(s) will not be cropped.

    scale : boolean or 2-tuple
        if true, the base image is displayed between its 0.05 and 0.95
        quantiles
        if false, matplotlib chooses the display range
        if 2-tuple, the base image is displayed between these quantiles
        Scaling is switched off when the pixel type of `image` is not
        "float" or "double".

    domain_image_map : ANTsImage
        if supplied, the image(s) are resampled into the space of this image
        (through a newly created affine transform) before plotting - useful
        for non-standard image orientations.

    title : string
        add a title to the plot (drawn in `textfontcolor`)

    titlefontsize, title_dx, title_dy
        font size of the title and offsets added to its default position

    text : string
        text written in the empty lower-left panel of the 2x2 layout;
        ignored when `flat=True`

    textfontsize, textfontcolor, text_dx, text_dy
        font size and color of `text`, and offsets added to its default
        (centered) position

    filename : string
        if given, the resulting image will be saved to this file and the
        figure closed; otherwise the figure is shown

    dpi : integer
        determines resolution of image if saved to file. Higher values
        result in higher resolution images, but at a cost of having a
        larger file size

    figsize : float
        multiplier applied to the default 9x9 inch figure and to the
        orientation label font size

    flat : boolean
        if true, the ortho image will be plot in one row
        if false, the ortho image will be a 2x2 grid with the bottom
        left corner blank

    transparent : boolean
        passed to `savefig` when `filename` is given

    resample : boolean
        if true and the largest voxel spacing is more than 3 times the
        smallest, the image(s) are resampled with `resample_image((1, 1, 1))`
        and `xyz` is rescaled accordingly

    allow_xyz_change : boolean
        if true, `xyz` is shifted by the amount of padding added in front of
        each axis so that it keeps pointing at the same voxel

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if an argument is of the wrong kind or shape (non-image input,
        non-3D image, multi-component image/overlay, malformed `scale` or
        `xyz`, `blend` without `overlay`, non-image `domain_image_map`)

    Example
    -------
    >>> import ants
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> ants.plot_ortho(mni, xyz=(100,100,100))
    >>> mni2 = mni.threshold_image(7000, mni.max())
    >>> ants.plot_ortho(mni, overlay=mni2)
    >>> ants.plot_ortho(mni, overlay=mni2, flat=True)
    >>> ants.plot_ortho(mni, overlay=mni2, xyz=(110,110,110), xyz_lines=False,
                        text='Lines Turned Off', textfontsize=22)
    >>> ants.plot_ortho(mni, mni2, xyz=(120,100,100),
                        text=' Example \nOrtho Text', textfontsize=26,
                        title='Example Ortho Title', titlefontsize=26)
    """
    # Warnings are silenced because matplotlib emits a spurious NaN warning
    # with overlays. catch_warnings() restores the caller's filters on exit,
    # including when an exception is raised.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # handle `image` argument
        if isinstance(image, str):
            image = ants.image_read(image)
        if not ants.is_image(image):
            raise ValueError("image argument must be an ANTsImage")
        if image.dimension != 3:
            raise ValueError("Input image must have 3 dimensions!")

        # handle `overlay` argument
        vminol = vmaxol = None
        if overlay is not None:
            if isinstance(overlay, str):
                overlay = ants.image_read(overlay)
            if not ants.is_image(overlay):
                raise ValueError("overlay argument must be an ANTsImage")
            if overlay.components > 1:
                raise ValueError("overlay cannot have more than one voxel component")
            if overlay.dimension != 3:
                raise ValueError("Overlay image must have 3 dimensions!")
            # color limits come from the overlay as supplied, i.e. before any
            # resampling below
            vminol = overlay.min()
            vmaxol = overlay.max()
            if not ants.image_physical_space_consistency(image, overlay):
                overlay = ants.resample_image_to_target(
                    overlay, image, interp_type="linear"
                )

        if blend:
            if overlay is None:
                raise ValueError("blend=True requires an overlay image")
            if alpha == 1:
                alpha = 0.5
            image = image * alpha + overlay * (1 - alpha)
            overlay = None

        if image.pixeltype not in {"float", "double"}:
            scale = False  # turn off scaling if image is discrete

        # reorient images; `!=` (not `is not`) keeps 0 meaning "do not reorient"
        if reorient != False:  # noqa: E712
            image = image.reorient_image2("RPI")
            if overlay is not None:
                overlay = overlay.reorient_image2("RPI")

        # handle `xyz` argument; work on a copy so that tuples are accepted
        # and the caller's sequence is never modified
        centers = [int(s / 2) for s in image.shape]
        if xyz is None:
            xyz = centers
        else:
            xyz = list(xyz)
            if len(xyz) != 3:
                raise ValueError("xyz argument must have exactly three entries")
            xyz = [c if v is None else v for v, c in zip(xyz, centers)]

        # resample image if spacing is very unbalanced
        spacing = list(image.spacing)
        if resample and (max(spacing) / min(spacing)) > 3.0:
            new_spacing = (1, 1, 1)
            image = image.resample_image(new_spacing)
            if overlay is not None:
                overlay = overlay.resample_image(new_spacing)
            xyz = [
                int(idx * (old / new))
                for idx, old, new in zip(xyz, spacing, new_spacing)
            ]

        # potentially crop image
        if crop:
            plotmask = image.get_mask(cleanup=0)
            if plotmask.max() == 0:
                plotmask += 1
            image = image.crop_image(plotmask)
            if overlay is not None:
                overlay = overlay.crop_image(plotmask)

        # pad images
        image, lowpad, _ = image.pad_image(return_padvals=True)
        if allow_xyz_change:
            xyz = [idx + low for idx, low in zip(xyz, lowpad)]
        if overlay is not None:
            overlay = overlay.pad_image()

        # handle `domain_image_map` argument
        if domain_image_map is not None:
            if not ants.is_image(domain_image_map):
                raise ValueError('The domain_image_map must be an image.')
            tx = ants.new_ants_transform(
                precision="float",
                transform_type="AffineTransform",
                dimension=image.dimension,
            )
            image = ants.apply_ants_transform_to_image(tx, image, domain_image_map)
            if overlay is not None:
                overlay = ants.apply_ants_transform_to_image(
                    tx, overlay, domain_image_map, interpolation="linear"
                )

        if image.components != 1:
            raise ValueError("Multi-channel images not currently supported!")

        # potentially find dynamic range; `== True` also accepts 1 / numpy bools
        if scale == True:  # noqa: E712
            vmin, vmax = image.quantile((0.05, 0.95))
        elif isinstance(scale, (list, tuple)):
            if len(scale) != 2:
                raise ValueError(
                    "scale argument must be boolean or list/tuple with two values"
                )
            vmin, vmax = image.quantile(scale)
        else:
            vmin = vmax = None

        nrow, ncol = (1, 3) if flat else (2, 2)

        fig = plt.figure(figsize=(9 * figsize, 9 * figsize))
        if title is not None:
            basey = 0.66 if flat else 0.88
            fig.suptitle(
                title,
                fontsize=titlefontsize,
                color=textfontcolor,
                x=0.5 + title_dx,
                y=basey + title_dy,
            )

        gs = gridspec.GridSpec(
            nrow,
            ncol,
            wspace=0.0,
            hspace=0.0,
            top=1.0 - 0.5 / (nrow + 1),
            bottom=0.5 / (nrow + 1),
            left=0.5 / (ncol + 1),
            right=1 - 0.5 / (ncol + 1),
        )

        image_arr = image.numpy()
        overlay_arr = None
        if overlay is not None:
            overlay_arr = overlay.numpy()
            if overlay_arr.dtype not in ["uint8", "uint32"]:
                # hide (near-)zero overlay voxels so the base image shows through
                overlay_arr = np.ma.masked_where(
                    np.abs(overlay_arr) <= 1e-16, overlay_arr
                )

        ix, iy, iz = xyz
        everything = slice(None)

        def draw_panel(cell, index, labels, crosshair):
            """Render one view into grid cell `cell`; return the top-most image."""
            ax = plt.subplot(cell)
            base_slice = _reorient_slice(image_arr[index])
            shown = ax.imshow(base_slice, cmap=cmap, vmin=vmin, vmax=vmax)
            if overlay_arr is not None:
                shown = ax.imshow(
                    _reorient_slice(overlay_arr[index]),
                    alpha=overlay_alpha,
                    cmap=overlay_cmap,
                    vmin=vminol,
                    vmax=vmaxol,
                )
            if xyz_lines:
                x_pos, y_pos = crosshair(base_slice.shape)
                _add_crosshair(
                    ax,
                    x_pos,
                    y_pos,
                    base_slice.shape,
                    xyz_pad,
                    xyz_color,
                    xyz_alpha,
                    xyz_linewidth,
                )
            if orient_labels:
                _add_orient_labels(
                    ax, *labels, fontsize=20 * figsize, color=textfontcolor
                )
            ax.axis("off")
            return shown

        # NOTE(review): the crosshair formulas below are kept exactly as they
        # were. They mix shape[0]/shape[1] in a way that is only consistent
        # for square slices (presumably guaranteed by pad_image - confirm),
        # and the `shape - index` x positions look mirrored relative to the
        # displayed columns - verify against a non-centered xyz.

        # yz plane (labels: top, bottom, right, left)
        draw_panel(
            gs[0, 0],
            (ix, everything, everything),
            ("S", "I", "A", "P"),
            lambda shape: (iy, shape[1] - iz),
        )
        # xz plane
        draw_panel(
            gs[0, 1],
            (everything, iy, everything),
            ("S", "I", "L", "R"),
            lambda shape: (shape[0] - ix, shape[1] - iz),
        )
        # xy plane; its top-most layer is what the colorbar describes
        im = draw_panel(
            gs[0, 2] if flat else gs[1, 1],
            (everything, everything, iz),
            ("A", "P", "L", "R"),
            lambda shape: (shape[0] - ix, shape[1] - iy),
        )

        if not flat:
            # empty corner, optionally holding the free text
            ax = plt.subplot(gs[1, 0])
            if text is not None:
                ax.text(
                    0.5 + text_dx,
                    0.5 + text_dy,
                    text,
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=textfontsize,
                    color=textfontcolor,
                    transform=ax.transAxes,
                )
            ax.imshow(np.zeros(image_arr.shape[:-1]), cmap="Greys_r")
            ax.axis("off")

        if cbar:
            cbar_start = (1 - cbar_length) / 2
            if cbar_vertical:
                cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length])
                cbar_orient = "vertical"
            else:
                cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03])
                cbar_orient = "horizontal"
            fig.colorbar(im, cax=cax, orientation=cbar_orient)

        if filename is not None:
            plt.savefig(filename, dpi=dpi, transparent=transparent)
            plt.close(fig)
        else:
            plt.show()