--- a +++ b/dosma/utils/img_utils.py @@ -0,0 +1,127 @@ +import itertools + +import numpy as np +import seaborn as sns + +from dosma import defaults + +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D + +__all__ = ["downsample_slice", "write_regions"] + + +def downsample_slice(img_array, ds_factor, is_mask=False): + """ + Takes in a 3D array and then downsamples in the z-direction by a + user-specified downsampling factor. + + Args: + img_array (np.ndarray): 3D numpy array for now (xres x yres x zres) + ds_factor (int): Downsampling factor + is_mask (:obj:`bool`, optional): If ``True``, ``img_array`` is a mask and will be binarized + after downsampling. Defaults to `False`. + + Returns: + np.ndarray: 3D numpy array of dimensions (xres x yres x zres//ds_factor) + + Examples: + >>> input_image = numpy.random.rand(4,4,4) + >>> input_mask = (a > 0.5) * 1.0 + >>> output_image = downsample_slice(input_mask, ds_factor = 2, is_mask = False) + >>> output_mask = downsample_slice(input_mask, ds_factor = 2, is_mask = True) + """ + + img_array = np.transpose(img_array, (2, 0, 1)) + L = list(img_array) + + def grouper(iterable, n): + args = [iter(iterable)] * n + return itertools.zip_longest(fillvalue=0, *args) + + final = np.array([sum(x) for x in grouper(L, ds_factor)]) + final = np.transpose(final, (1, 2, 0)) + + # Binarize if it is a mask. + if is_mask is True: + final = (final >= 1) * 1 + + return final + + +def write_regions(file_path, arr, plt_dict=None): + """Write 2D array to region image where colors correspond to the region. + + All finite values should be >= 1. + nan/inf value are ignored - written as white. + + Args: + file_path (str): File path to save image. + arr (np.ndarray): The 2D numpy array to convert to region image. + Unique non-zero values correspond to different regions. + Values that are `0` or `np.nan` will be written as white pixels. + plt_dict (:obj:`dict`, optional): Dictionary of values to use when plotting with + ``matplotlib.pyplot``. Keys are strings like `xlabel`, `ylabel`, etc. + Use Key `labels` to specify a mapping from unique non-zero values in the array + to names for the legend. + """ + + if len(arr.shape) != 2: + raise ValueError("`arr` must be a 2D numpy array") + + unique_vals = np.unique(arr.flatten()) + if 0 in unique_vals: + raise ValueError("All finite values in `arr` must be >=1") + + unique_vals = unique_vals[np.isfinite(unique_vals)] + num_unique_vals = len(unique_vals) + + plt_dict_int = {"xlabel": "", "ylabel": "", "title": "", "labels": None} + if plt_dict: + plt_dict_int.update(plt_dict) + plt_dict = plt_dict_int + + labels = plt_dict["labels"] + if labels is None: + labels = list(unique_vals) + + if len(labels) != num_unique_vals: + raise ValueError( + "len(labels) != num_unique_vals - %d != %d" % (len(labels), num_unique_vals) + ) + + cpal = sns.color_palette("pastel", num_unique_vals) + + arr_c = np.array(arr) + arr_c = np.nan_to_num(arr_c) + arr_c[arr_c > np.max(unique_vals)] = 0 + arr_rgb = np.ones([arr_c.shape[0], arr_c.shape[1], 3]) + + plt.figure() + plt.clf() + custom_lines = [] + for i in range(num_unique_vals): + unique_val = unique_vals[i] + i0, i1 = np.where(arr_c == unique_val) + arr_rgb[i0, i1, ...] = np.asarray(cpal[i]) + + custom_lines.append( + Line2D([], [], color=cpal[i], marker="o", linestyle="None", markersize=5) + ) + + plt.xlabel(plt_dict["xlabel"]) + plt.ylabel(plt_dict["ylabel"]) + plt.title(plt_dict["title"]) + + lgd = plt.legend( + custom_lines, + labels, + loc="upper center", + bbox_to_anchor=(0.5, -defaults.DEFAULT_TEXT_SPACING), + fancybox=True, + shadow=True, + ncol=3, + ) + plt.imshow(arr_rgb) + + plt.savefig(file_path, bbox_extra_artists=(lgd,), bbox_inches="tight")