a b/dosma/utils/img_utils.py
1
import itertools
2
3
import numpy as np
4
import seaborn as sns
5
6
from dosma import defaults
7
8
import matplotlib.pyplot as plt
9
from matplotlib.lines import Line2D
10
11
# Names exported by ``from dosma.utils.img_utils import *``.
__all__ = [
    "downsample_slice",
    "write_regions",
]
12
13
14
def downsample_slice(img_array, ds_factor, is_mask=False):
    """
    Takes in a 3D array and then downsamples in the z-direction by a
    user-specified downsampling factor.

    Each group of ``ds_factor`` consecutive slices along the last axis is summed
    into a single output slice. If the number of slices is not divisible by
    ``ds_factor``, the trailing partial group is summed on its own (equivalent to
    padding it with zero-valued slices).

    Args:
        img_array (np.ndarray): 3D numpy array for now (xres x yres x zres)
        ds_factor (int): Downsampling factor. Must be a positive integer.
        is_mask (:obj:`bool`, optional): If truthy, ``img_array`` is a mask and will be
            binarized (``1`` where the summed value is ``>= 1``, else ``0``) after
            downsampling. Defaults to `False`.

    Returns:
        np.ndarray: 3D numpy array of dimensions
        (xres x yres x ceil(zres / ds_factor)).

    Raises:
        ValueError: If ``img_array`` is not 3D or ``ds_factor`` is less than 1.

    Examples:
        >>> input_image  = np.random.rand(4,4,4)
        >>> input_mask   = (input_image > 0.5) * 1.0
        >>> output_image = downsample_slice(input_image, ds_factor = 2, is_mask = False)
        >>> output_mask  = downsample_slice(input_mask, ds_factor = 2, is_mask = True)
    """
    img_array = np.asarray(img_array)
    if img_array.ndim != 3:
        raise ValueError(
            "`img_array` must be a 3D array, got %d dimension(s)" % img_array.ndim
        )
    if ds_factor < 1:
        raise ValueError("`ds_factor` must be a positive integer, got %r" % (ds_factor,))

    num_slices = img_array.shape[2]
    if num_slices == 0:
        # Nothing to downsample: keep the in-plane shape with zero slices.
        final = np.zeros(img_array.shape, dtype=img_array.dtype)
    else:
        # Built-in ``sum`` adds whole 2D slices one at a time starting from 0, so
        # the result dtype follows ordinary array addition with the input slices.
        groups = [
            sum(
                img_array[:, :, k]
                for k in range(start, min(start + ds_factor, num_slices))
            )
            for start in range(0, num_slices, ds_factor)
        ]
        final = np.stack(groups, axis=-1)

    # Binarize if it is a mask.
    if is_mask:
        final = (final >= 1) * 1

    return final
50
51
52
def write_regions(file_path, arr, plt_dict=None):
    """Write 2D array to region image where colors correspond to the region.

    All finite values should be >= 1; a ``ValueError`` is raised if ``arr``
    contains ``0``. nan/inf values are ignored - written as white.

    Args:
        file_path (str): File path to save image.
        arr (np.ndarray): The 2D numpy array to convert to region image.
            Unique finite values correspond to different regions and are assigned
            palette colors in ascending order of value. Values that are `np.nan`
            or `np.inf`/`-np.inf` are written as white pixels.
        plt_dict (:obj:`dict`, optional): Dictionary of values to use when plotting with
            ``matplotlib.pyplot``. Supported keys are `xlabel`, `ylabel`, `title` and
            `labels`. `labels` gives the legend names for the regions, either as a
            sequence ordered like the sorted unique finite values of ``arr``, or as a
            mapping from each unique finite value to its name. If omitted, the
            values themselves are used as names.

    Raises:
        ValueError: If ``arr`` is not 2D, contains ``0``, if a ``labels`` mapping
            lacks a name for some region value, or if the number of labels does not
            match the number of unique finite values.
    """
    from collections.abc import Mapping

    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError("`arr` must be a 2D numpy array")

    # np.unique flattens the input and returns its sorted unique values.
    unique_vals = np.unique(arr)
    if 0 in unique_vals:
        raise ValueError("All finite values in `arr` must be >=1")

    # nan/inf are not regions: drop them so they stay white in the output.
    unique_vals = unique_vals[np.isfinite(unique_vals)]
    num_unique_vals = len(unique_vals)

    plt_dict_int = {"xlabel": "", "ylabel": "", "title": "", "labels": None}
    if plt_dict:
        plt_dict_int.update(plt_dict)
    plt_dict = plt_dict_int

    labels = plt_dict["labels"]
    if labels is None:
        labels = list(unique_vals)
    elif isinstance(labels, Mapping):
        # Resolve a value -> name mapping into the sorted order used for colors.
        missing = [v for v in unique_vals if v not in labels]
        if missing:
            raise ValueError("`labels` is missing names for region values %s" % missing)
        labels = [labels[v] for v in unique_vals]

    if len(labels) != num_unique_vals:
        raise ValueError(
            "len(labels) != num_unique_vals - %d != %d" % (len(labels), num_unique_vals)
        )

    cpal = sns.color_palette("pastel", num_unique_vals) if num_unique_vals else []

    # Start from an all-white RGB image and paint each region with its color.
    # Comparing against the raw array means nan/inf pixels never equal a finite
    # region value, so they remain white without any remapping.
    arr_rgb = np.ones(arr.shape + (3,))
    custom_lines = []
    for color, unique_val in zip(cpal, unique_vals):
        arr_rgb[arr == unique_val] = np.asarray(color)
        custom_lines.append(
            Line2D([], [], color=color, marker="o", linestyle="None", markersize=5)
        )

    fig = plt.figure()
    try:
        plt.xlabel(plt_dict["xlabel"])
        plt.ylabel(plt_dict["ylabel"])
        plt.title(plt_dict["title"])

        lgd = plt.legend(
            custom_lines,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, -defaults.DEFAULT_TEXT_SPACING),
            fancybox=True,
            shadow=True,
            ncol=3,
        )
        plt.imshow(arr_rgb)

        # Include the legend (placed below the axes) in the tight bounding box.
        plt.savefig(file_path, bbox_extra_artists=(lgd,), bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)