Diff of /dosma/utils/img_utils.py [000000] .. [030aeb]

Switch to side-by-side view

--- a
+++ b/dosma/utils/img_utils.py
@@ -0,0 +1,127 @@
+import itertools
+
+import numpy as np
+import seaborn as sns
+
+from dosma import defaults
+
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+__all__ = ["downsample_slice", "write_regions"]
+
+
+def downsample_slice(img_array, ds_factor, is_mask=False):
+    """
+    Takes in a 3D array and then downsamples in the z-direction by a
+    user-specified downsampling factor.
+
+    Args:
+        img_array (np.ndarray): 3D numpy array for now (xres x yres x zres)
+        ds_factor (int): Downsampling factor
+        is_mask (:obj:`bool`, optional): If ``True``, ``img_array`` is a mask and will be binarized
+            after downsampling. Defaults to `False`.
+
+    Returns:
+        np.ndarray: 3D numpy array of dimensions (xres x yres x zres//ds_factor)
+
+    Examples:
+        >>> input_image  = numpy.random.rand(4,4,4)
+        >>> input_mask   = (a > 0.5) * 1.0
+        >>> output_image = downsample_slice(input_mask, ds_factor = 2, is_mask = False)
+        >>> output_mask  = downsample_slice(input_mask, ds_factor = 2, is_mask = True)
+    """
+
+    img_array = np.transpose(img_array, (2, 0, 1))
+    L = list(img_array)
+
+    def grouper(iterable, n):
+        args = [iter(iterable)] * n
+        return itertools.zip_longest(fillvalue=0, *args)
+
+    final = np.array([sum(x) for x in grouper(L, ds_factor)])
+    final = np.transpose(final, (1, 2, 0))
+
+    # Binarize if it is a mask.
+    if is_mask is True:
+        final = (final >= 1) * 1
+
+    return final
+
+
+def write_regions(file_path, arr, plt_dict=None):
+    """Write 2D array to region image where colors correspond to the region.
+
+    All finite values should be >= 1.
+    nan/inf value are ignored - written as white.
+
+    Args:
+        file_path (str): File path to save image.
+        arr (np.ndarray): The 2D numpy array to convert to region image.
+            Unique non-zero values correspond to different regions.
+            Values that are `0` or `np.nan` will be written as white pixels.
+        plt_dict (:obj:`dict`, optional): Dictionary of values to use when plotting with
+            ``matplotlib.pyplot``. Keys are strings like `xlabel`, `ylabel`, etc.
+            Use Key `labels` to specify a mapping from unique non-zero values in the array
+            to names for the legend.
+    """
+
+    if len(arr.shape) != 2:
+        raise ValueError("`arr` must be a 2D numpy array")
+
+    unique_vals = np.unique(arr.flatten())
+    if 0 in unique_vals:
+        raise ValueError("All finite values in `arr` must be >=1")
+
+    unique_vals = unique_vals[np.isfinite(unique_vals)]
+    num_unique_vals = len(unique_vals)
+
+    plt_dict_int = {"xlabel": "", "ylabel": "", "title": "", "labels": None}
+    if plt_dict:
+        plt_dict_int.update(plt_dict)
+    plt_dict = plt_dict_int
+
+    labels = plt_dict["labels"]
+    if labels is None:
+        labels = list(unique_vals)
+
+    if len(labels) != num_unique_vals:
+        raise ValueError(
+            "len(labels) != num_unique_vals - %d != %d" % (len(labels), num_unique_vals)
+        )
+
+    cpal = sns.color_palette("pastel", num_unique_vals)
+
+    arr_c = np.array(arr)
+    arr_c = np.nan_to_num(arr_c)
+    arr_c[arr_c > np.max(unique_vals)] = 0
+    arr_rgb = np.ones([arr_c.shape[0], arr_c.shape[1], 3])
+
+    plt.figure()
+    plt.clf()
+    custom_lines = []
+    for i in range(num_unique_vals):
+        unique_val = unique_vals[i]
+        i0, i1 = np.where(arr_c == unique_val)
+        arr_rgb[i0, i1, ...] = np.asarray(cpal[i])
+
+        custom_lines.append(
+            Line2D([], [], color=cpal[i], marker="o", linestyle="None", markersize=5)
+        )
+
+    plt.xlabel(plt_dict["xlabel"])
+    plt.ylabel(plt_dict["ylabel"])
+    plt.title(plt_dict["title"])
+
+    lgd = plt.legend(
+        custom_lines,
+        labels,
+        loc="upper center",
+        bbox_to_anchor=(0.5, -defaults.DEFAULT_TEXT_SPACING),
+        fancybox=True,
+        shadow=True,
+        ncol=3,
+    )
+    plt.imshow(arr_rgb)
+
+    plt.savefig(file_path, bbox_extra_artists=(lgd,), bbox_inches="tight")