|
a |
|
b/dosma/utils/img_utils.py |
|
|
1 |
import itertools |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import seaborn as sns |
|
|
5 |
|
|
|
6 |
from dosma import defaults |
|
|
7 |
|
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
from matplotlib.lines import Line2D |
|
|
10 |
|
|
|
11 |
__all__ = ["downsample_slice", "write_regions"] |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def downsample_slice(img_array, ds_factor, is_mask=False):
    """Downsample a 3D array along its last (z) axis by summing slices.

    Consecutive groups of ``ds_factor`` slices are summed into a single output
    slice. If the number of slices is not a multiple of ``ds_factor``, the last
    group is shorter and is summed as-is (equivalent to zero-padding it).

    Args:
        img_array (np.ndarray): 3D numpy array (xres x yres x zres).
        ds_factor (int): Downsampling factor. Must be >= 1.
        is_mask (:obj:`bool`, optional): If truthy, ``img_array`` is a mask and the
            output is binarized after downsampling (values >= 1 become 1, all
            other values become 0). Defaults to ``False``.

    Returns:
        np.ndarray: 3D numpy array of dimensions
            (xres x yres x ceil(zres / ds_factor)).

    Raises:
        ValueError: If ``ds_factor`` is less than 1, or if ``img_array`` is not
            3D (raised by ``np.transpose``).

    Examples:
        >>> input_image = np.random.rand(4, 4, 4)
        >>> input_mask = (input_image > 0.5) * 1.0
        >>> output_image = downsample_slice(input_image, ds_factor=2, is_mask=False)
        >>> output_mask = downsample_slice(input_mask, ds_factor=2, is_mask=True)
    """
    if ds_factor < 1:
        raise ValueError("`ds_factor` must be >= 1, got %r" % (ds_factor,))

    # Move z to the front so that iterating yields one (xres, yres) slice at a time.
    slices = np.transpose(img_array, (2, 0, 1))
    num_slices = slices.shape[0]

    # The builtin ``sum`` adds the 2D slices of each group elementwise; the
    # trailing group may hold fewer than ``ds_factor`` slices.
    grouped = [
        sum(slices[start : start + ds_factor])
        for start in range(0, num_slices, ds_factor)
    ]

    if grouped:
        final = np.transpose(np.array(grouped), (1, 2, 0))
    else:
        # No slices to combine: keep the in-plane shape with an empty z-axis.
        final = np.zeros(slices.shape[1:] + (0,), dtype=slices.dtype)

    # Binarize if it is a mask. Truthiness (rather than ``is True``) is used so
    # that flags such as ``np.bool_(True)`` are honored.
    if is_mask:
        final = (final >= 1) * 1

    return final
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def write_regions(file_path, arr, plt_dict=None):
    """Write 2D array to region image where colors correspond to the region.

    Each distinct finite value in ``arr`` is drawn in its own pastel color and
    gets an entry in a legend placed below the image. Non-finite values
    (nan/inf) are ignored - written as white.

    Args:
        file_path (str): File path to save image.
        arr (np.ndarray): The 2D numpy array to convert to region image.
            Unique finite values correspond to different regions and should be
            >= 1; an array containing ``0`` is rejected.
            Values that are ``np.nan`` or infinite are written as white pixels.
        plt_dict (:obj:`dict`, optional): Dictionary of values to use when plotting with
            ``matplotlib.pyplot``. Keys ``xlabel``, ``ylabel`` and ``title`` set
            the corresponding text (each defaults to an empty string).
            Use key ``labels`` to specify the legend names, one per unique
            finite value in ascending order of value. Defaults to the unique
            values themselves.

    Raises:
        ValueError: If ``arr`` is not 2D, if ``arr`` contains ``0``, or if the
            number of labels differs from the number of unique finite values.
    """
    if len(arr.shape) != 2:
        raise ValueError("`arr` must be a 2D numpy array")

    unique_vals = np.unique(arr.flatten())
    if 0 in unique_vals:
        raise ValueError("All finite values in `arr` must be >=1")

    # Only finite values define regions; nan/inf pixels stay white.
    unique_vals = unique_vals[np.isfinite(unique_vals)]
    num_unique_vals = len(unique_vals)

    # Merge caller options over the defaults without mutating the caller's dict.
    plot_opts = {"xlabel": "", "ylabel": "", "title": "", "labels": None}
    if plt_dict:
        plot_opts.update(plt_dict)

    labels = plot_opts["labels"]
    if labels is None:
        labels = list(unique_vals)

    if len(labels) != num_unique_vals:
        raise ValueError(
            "len(labels) != num_unique_vals - %d != %d" % (len(labels), num_unique_vals)
        )

    cpal = sns.color_palette("pastel", num_unique_vals)

    # Start from an all-white RGB canvas and paint each region with its color.
    # Comparing against the raw values leaves nan/inf pixels untouched because
    # they never compare equal to a finite region value.
    values = np.asarray(arr)
    arr_rgb = np.ones([values.shape[0], values.shape[1], 3])
    custom_lines = []
    for color, unique_val in zip(cpal, unique_vals):
        arr_rgb[values == unique_val] = np.asarray(color)

        # Proxy artist so the legend shows a colored dot for this region.
        custom_lines.append(
            Line2D([], [], color=color, marker="o", linestyle="None", markersize=5)
        )

    fig = plt.figure()
    try:
        plt.xlabel(plot_opts["xlabel"])
        plt.ylabel(plot_opts["ylabel"])
        plt.title(plot_opts["title"])

        lgd = plt.legend(
            custom_lines,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, -defaults.DEFAULT_TEXT_SPACING),
            fancybox=True,
            shadow=True,
            ncol=3,
        )
        plt.imshow(arr_rgb)

        # The legend sits outside the axes, so it must be passed explicitly for
        # the tight bounding box to include it.
        plt.savefig(file_path, bbox_extra_artists=(lgd,), bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)