--- a +++ b/ants/plotting/plot.py @@ -0,0 +1,486 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + +import numpy as np +import ants +from ants.decorators import image_method + +@image_method +def plot( + image, + overlay=None, + blend=False, + alpha=1, + cmap="Greys_r", + overlay_cmap="turbo", + overlay_alpha=0.9, + vminol=None, + vmaxol=None, + cbar=False, + cbar_length=0.8, + cbar_dx=0.0, + cbar_vertical=True, + axis=0, + nslices=12, + slices=None, + ncol=None, + slice_buffer=None, + black_bg=True, + bg_thresh_quant=0.01, + bg_val_quant=0.99, + domain_image_map=None, + crop=False, + scale=False, + reverse=False, + title=None, + title_fontsize=20, + title_dx=0.0, + title_dy=0.0, + filename=None, + dpi=500, + figsize=1.5, + reorient=True, + resample=True, +): + """ + Plot an ANTsImage. + + Use mask_image and/or threshold_image to preprocess images to be be + overlaid and display the overlays in a given range. See the wiki examples. + + By default, images will be reoriented to 'LAI' orientation before plotting. + So, if axis == 0, the images will be ordered from the + left side of the brain to the right side of the brain. If axis == 1, + the images will be ordered from the anterior (front) of the brain to + the posterior (back) of the brain. And if axis == 2, the images will + be ordered from the inferior (bottom) of the brain to the superior (top) + of the brain. + + ANTsR function: `plot.antsImage` + + Arguments + --------- + image : ANTsImage + image to plot + + overlay : ANTsImage + image to overlay on base image + + cmap : string + colormap to use for base image. See matplotlib. + + overlay_cmap : string + colormap to use for overlay images, if applicable. See matplotlib. + + overlay_alpha : float + level of transparency for any overlays. Smaller value means + the overlay is more transparent. See matplotlib. + + axis : integer + which axis to plot along if image is 3D + + nslices : integer + number of slices to plot if image is 3D + + slices : list or tuple of integers + specific slice indices to plot if image is 3D. + If given, this will override `nslices`. + This can be absolute array indices (e.g. (80,100,120)), or + this can be relative array indices (e.g. (0.4,0.5,0.6)) + + ncol : integer + Number of columns to have on the plot if image is 3D. + + slice_buffer : integer + how many slices to buffer when finding the non-zero slices of + a 3D images. So, if slice_buffer = 10, then the first slice + in a 3D image will be the first non-zero slice index plus 10 more + slices. + + black_bg : boolean + if True, the background of the image(s) will be black. + if False, the background of the image(s) will be determined by the + values `bg_thresh_quant` and `bg_val_quant`. + + bg_thresh_quant : float + if white_bg=True, the background will be determined by thresholding + the image at the `bg_thresh` quantile value and setting the background + intensity to the `bg_val` quantile value. + This value should be in [0, 1] - somewhere around 0.01 is recommended. + - equal to 1 will threshold the entire image + - equal to 0 will threshold none of the image + + bg_val_quant : float + if white_bg=True, the background will be determined by thresholding + the image at the `bg_thresh` quantile value and setting the background + intensity to the `bg_val` quantile value. + This value should be in [0, 1] + - equal to 1 is pure white + - equal to 0 is pure black + - somewhere in between is gray + + domain_image_map : ANTsImage + this input ANTsImage or list of ANTsImage types contains a reference image + `domain_image` and optional reference mapping named `domainMap`. + If supplied, the image(s) to be plotted will be mapped to the domain + image space before plotting - useful for non-standard image orientations. + + crop : boolean + if true, the image(s) will be cropped to their bounding boxes, resulting + in a potentially smaller image size. + if false, the image(s) will not be cropped + + scale : boolean or 2-tuple + if true, nothing will happen to intensities of image(s) and overlay(s) + if false, dynamic range will be maximized when visualizing overlays + if 2-tuple, the image will be dynamically scaled between these quantiles + + reverse : boolean + if true, the order in which the slices are plotted will be reversed. + This is useful if you want to plot from the front of the brain first + to the back of the brain, or vice-versa + + title : string + add a title to the plot + + filename : string + if given, the resulting image will be saved to this file + + dpi : integer + determines resolution of image if saved to file. Higher values + result in higher resolution images, but at a cost of having a + larger file size + + resample : bool + if true, resample image if spacing is very unbalanced. + + Example + ------- + >>> import ants + >>> import numpy as np + >>> img = ants.image_read(ants.get_data('r16')) + >>> segs = img.kmeans_segmentation(k=3)['segmentation'] + >>> ants.plot(img, segs*(segs==1), crop=True) + >>> ants.plot(img, segs*(segs==1), crop=False) + >>> mni = ants.image_read(ants.get_data('mni')) + >>> segs = mni.kmeans_segmentation(k=3)['segmentation'] + >>> ants.plot(mni, segs*(segs==1), crop=False) + """ + if (axis == "x") or (axis == "saggittal"): + axis = 0 + if (axis == "y") or (axis == "coronal"): + axis = 1 + if (axis == "z") or (axis == "axial"): + axis = 2 + + def mirror_matrix(x): + return x[::-1, :] + + def rotate270_matrix(x): + return mirror_matrix(x.T) + + def rotate180_matrix(x): + return x[::-1, ::-1] + + def rotate90_matrix(x): + return x.T + + def reorient_slice(x, axis): + if axis != 2: + x = rotate90_matrix(x) + if axis == 2: + x = rotate270_matrix(x) + x = mirror_matrix(x) + return x + + + # handle `image` argument + if isinstance(image, str): + image = ants.image_read(image) + if not ants.is_image(image): + raise ValueError("image argument must be an ANTsImage") + + if np.all(np.equal(image.numpy(), 0.0)): + warnings.warn("Image must be non-zero. will not plot.") + return + + # need this hack because of a weird NaN warning from matplotlib with overlays + warnings.simplefilter("ignore") + + if (image.pixeltype not in {"float", "double"}) or (image.is_rgb): + scale = False # turn off scaling if image is discrete + + # handle `overlay` argument + if overlay is not None: + if isinstance(overlay, str): + overlay = ants.image_read(overlay) + if vminol is None: + vminol = overlay.min() + if vmaxol is None: + vmaxol = overlay.max() + if not ants.is_image(overlay): + raise ValueError("overlay argument must be an ANTsImage") + if overlay.components > 1: + raise ValueError("overlay cannot have more than one voxel component") + + if not ants.image_physical_space_consistency(image, overlay): + overlay = ants.resample_image_to_target(overlay, image, interp_type="nearestNeighbor") + + if blend: + if alpha == 1: + alpha = 0.5 + image = image * alpha + overlay * (1 - alpha) + overlay = None + alpha = 1.0 + + # handle `domain_image_map` argument + if domain_image_map is not None: + tx = ants.new_ants_transform( + precision="float", + transform_type="AffineTransform", + dimension=image.dimension, + ) + image = ants.apply_ants_transform_to_image(tx, image, domain_image_map) + if overlay is not None: + overlay = ants.apply_ants_transform_to_image( + tx, overlay, domain_image_map, interpolation="nearestNeighbor" + ) + + ## single-channel images ## + if image.components == 1: + + # potentially crop image + if crop: + plotmask = image.get_mask(cleanup=0) + if plotmask.max() == 0: + plotmask += 1 + image = image.crop_image(plotmask) + if overlay is not None: + overlay = overlay.crop_image(plotmask) + + # potentially find dynamic range + if scale == True: + vmin, vmax = image.quantile((0.05, 0.95)) + elif isinstance(scale, (list, tuple)): + if len(scale) != 2: + raise ValueError( + "scale argument must be boolean or list/tuple with two values" + ) + vmin, vmax = image.quantile(scale) + else: + vmin = None + vmax = None + + # Plot 2D image + if image.dimension == 2: + + img_arr = image.numpy() + img_arr = rotate90_matrix(img_arr) + + if not black_bg: + img_arr[img_arr < image.quantile(bg_thresh_quant)] = image.quantile( + bg_val_quant + ) + + if overlay is not None: + ov_arr = overlay.numpy() + mask = ov_arr == 0 + mask = np.ma.masked_where(mask == 0, mask) + ov_arr = np.ma.masked_array(ov_arr, mask) + ov_arr = rotate90_matrix(ov_arr) + + fig = plt.figure() + if title is not None: + fig.suptitle( + title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy + ) + + ax = plt.subplot(111) + + # plot main image + im = ax.imshow(img_arr, cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax) + + if overlay is not None: + im = ax.imshow(ov_arr, alpha=overlay_alpha, cmap=overlay_cmap, + vmin=vminol, vmax=vmaxol ) + + if cbar: + cbar_orient = "vertical" if cbar_vertical else "horizontal" + fig.colorbar(im, orientation=cbar_orient) + + plt.axis("off") + + # Plot 3D image + elif image.dimension == 3: + # resample image if spacing is very unbalanced + spacing = [s for i, s in enumerate(image.spacing) if i != axis] + was_resampled = False + if (max(spacing) / min(spacing)) > 3.0 and resample: + was_resampled = True + new_spacing = (1, 1, 1) + image = image.resample_image(tuple(new_spacing)) + if overlay is not None: + overlay = overlay.resample_image(tuple(new_spacing)) + + if reorient: + image = image.reorient_image2("LAI") + img_arr = image.numpy() + # reorder dims so that chosen axis is first + img_arr = np.rollaxis(img_arr, axis) + + if overlay is not None: + if reorient: + overlay = overlay.reorient_image2("LAI") + ov_arr = overlay.numpy() + mask = ov_arr == 0 + mask = np.ma.masked_where(mask == 0, mask) + ov_arr = np.ma.masked_array(ov_arr, mask) + ov_arr = np.rollaxis(ov_arr, axis) + + if slices is None: + if not isinstance(slice_buffer, (list, tuple)): + if slice_buffer is None: + slice_buffer = ( + int(img_arr.shape[1] * 0.1), + int(img_arr.shape[2] * 0.1), + ) + else: + slice_buffer = (slice_buffer, slice_buffer) + nonzero = np.where(img_arr.sum(axis=(1, 2)) > 0.01)[0] + min_idx = nonzero[0] + slice_buffer[0] + max_idx = nonzero[-1] - slice_buffer[1] + if min_idx > max_idx: + temp = min_idx + min_idx = max_idx + max_idx = temp + if max_idx > nonzero.max(): + max_idx = nonzero.max() + if min_idx < 0: + min_idx = 0 + slice_idxs = np.linspace(min_idx, max_idx, nslices).astype("int") + if reverse: + slice_idxs = np.array(list(reversed(slice_idxs))) + else: + if isinstance(slices, (int, float)): + slices = [slices] + # if all slices are less than 1, infer that they are relative slices + if sum([s > 1 for s in slices]) == 0: + slices = [int(s * img_arr.shape[0]) for s in slices] + slice_idxs = slices + nslices = len(slices) + + if was_resampled: + # re-calculate slices to account for new image shape + slice_idxs = np.unique( + np.array( + [ + int(s * (image.shape[axis] / img_arr.shape[0])) + for s in slice_idxs + ] + ) + ) + + # only have one row if nslices <= 6 and user didnt specify ncol + if ncol is None: + if nslices <= 6: + ncol = nslices + else: + ncol = int(round(math.sqrt(nslices))) + + # calculate grid size + nrow = math.ceil(nslices / ncol) + xdim = img_arr.shape[2] + ydim = img_arr.shape[1] + + dim_ratio = ydim / xdim + fig = plt.figure( + figsize=((ncol + 1) * figsize * dim_ratio, (nrow + 1) * figsize) + ) + if title is not None: + fig.suptitle( + title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy + ) + + gs = gridspec.GridSpec( + nrow, + ncol, + wspace=0.0, + hspace=0.0, + top=1.0 - 0.5 / (nrow + 1), + bottom=0.5 / (nrow + 1), + left=0.5 / (ncol + 1), + right=1 - 0.5 / (ncol + 1), + ) + + slice_idx_idx = 0 + for i in range(nrow): + for j in range(ncol): + if slice_idx_idx < len(slice_idxs): + imslice = img_arr[slice_idxs[slice_idx_idx]] + imslice = reorient_slice(imslice, axis) + if not black_bg: + imslice[ + imslice < image.quantile(bg_thresh_quant) + ] = image.quantile(bg_val_quant) + else: + imslice = np.zeros_like(img_arr[0]) + imslice = reorient_slice(imslice, axis) + + ax = plt.subplot(gs[i, j]) + im = ax.imshow(imslice, cmap=cmap, vmin=vmin, vmax=vmax) + + if overlay is not None: + if slice_idx_idx < len(slice_idxs): + ovslice = ov_arr[slice_idxs[slice_idx_idx]] + ovslice = reorient_slice(ovslice, axis) + im = ax.imshow( + ovslice, alpha=overlay_alpha, cmap=overlay_cmap, + vmin=vminol, vmax=vmaxol ) + ax.axis("off") + slice_idx_idx += 1 + + if cbar: + cbar_start = (1 - cbar_length) / 2 + if cbar_vertical: + cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length]) + cbar_orient = "vertical" + else: + cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03]) + cbar_orient = "horizontal" + fig.colorbar(im, cax=cax, orientation=cbar_orient) + + ## multi-channel images ## + elif image.has_components: + raise Exception('Plotting images with components is not currently supported.') + + if filename is not None: + filename = os.path.expanduser(filename) + plt.savefig(filename, dpi=dpi, transparent=True, bbox_inches="tight") + plt.close(fig) + else: + plt.show() + + # turn warnings back to default + warnings.simplefilter("default") + +