|
a |
|
b/ants/plotting/plot_grid.py |
|
|
1 |
""" |
|
|
2 |
Functions for plotting ants images |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
# Names exported by ``from ants.plotting.plot_grid import *``.
__all__ = ["plot_grid"]
|
|
9 |
|
|
|
10 |
import fnmatch |
|
|
11 |
import math |
|
|
12 |
import os |
|
|
13 |
import warnings |
|
|
14 |
|
|
|
15 |
from matplotlib import gridspec |
|
|
16 |
import matplotlib.pyplot as plt |
|
|
17 |
import matplotlib.patheffects as path_effects |
|
|
18 |
import matplotlib.lines as mlines |
|
|
19 |
import matplotlib.patches as patches |
|
|
20 |
import matplotlib.mlab as mlab |
|
|
21 |
import matplotlib.animation as animation |
|
|
22 |
from mpl_toolkits.axes_grid1.inset_locator import inset_axes |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
import numpy as np |
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def _mirror_matrix(x):
    """Return ``x`` with its rows in reverse order (vertical flip)."""
    return x[::-1, :]


def _rotate90_matrix(x):
    """Return ``x`` vertically flipped and then transposed (a 90-degree rotation)."""
    return _mirror_matrix(x).T


def _reorient_slice(x):
    """Return a 2D slice oriented for display with ``imshow``.

    The same rotate-then-flip transform is applied regardless of the axis the
    slice was taken along.
    """
    return _mirror_matrix(_rotate90_matrix(x))


def _slice_image(img, axis, idx):
    """Return slice ``idx`` along ``axis`` of an image indexed with three subscripts.

    ``axis`` may be 0, 1, 2 or the equivalent negative index -3, -2, -1; the
    sliced image is converted with ``.numpy()``.

    Raises
    ------
    ValueError
        if ``axis`` is not one of the six accepted values.
    """
    if axis in (0, -3):
        return img[idx, :, :].numpy()
    if axis in (1, -2):
        return img[:, idx, :].numpy()
    if axis in (2, -1):
        return img[:, :, idx].numpy()
    raise ValueError("axis %i not valid" % axis)


def _default_slice_index(size):
    """Return the slice index used when none is given for an axis of length ``size``.

    Uses ``ceil(size / 2)`` (the historical choice), clamped to ``size - 1`` so
    that an axis of length 1 does not yield an out-of-range index.
    """
    return min(math.ceil(size / 2), size - 1)


def _broadcast_per_row(value, nrow, name):
    """Expand a whole-figure or per-row setting into a list with one entry per row.

    Lists, tuples and non-scalar numpy arrays are treated as per-row values.
    Anything else (a number, a string, ``None``, a matplotlib colormap object,
    ...) is repeated for every row.

    Raises
    ------
    ValueError
        if a per-row sequence has fewer than ``nrow`` entries (which would
        otherwise silently drop the trailing rows from the figure).
    """
    if isinstance(value, np.ndarray):
        value = value.tolist()  # a 0-d array becomes a plain scalar here
    if isinstance(value, (list, tuple)):
        if len(value) < nrow:
            raise ValueError(
                "`%s` has %i entries but the grid has %i rows"
                % (name, len(value), nrow)
            )
        return list(value)
    return [value] * nrow


def _draw_row_label(ax, text, is_first_row, fontsize, fontcolor, facecolor):
    """Draw a vertical row label on a filled background strip left of ``ax``."""
    ax.text(
        -0.07,
        0.5,
        text,
        horizontalalignment="right",
        verticalalignment="center",
        rotation="vertical",
        transform=ax.transAxes,
        color=fontcolor,
        fontsize=fontsize,
    )
    # On the first row the strip is 0.3 taller, so it also covers the top-left
    # corner beside the column-label strip (which spans y = 1.0 .. 1.3).
    extra = 0.3 if is_first_row else 0.0
    rect = patches.Rectangle(
        (-0.3, 0),
        0.3,
        1.0 + extra,
        facecolor=facecolor,
        alpha=1.0,
        transform=ax.transAxes,
        clip_on=False,
    )
    ax.add_patch(rect)


def _draw_col_label(ax, text, fontsize, fontcolor, facecolor):
    """Draw a horizontal column label on a filled background strip above ``ax``."""
    ax.text(
        0.5,
        0.09 + 0.75 + 0.25,  # ~1.09 in axes coordinates: just above the top edge
        text,
        horizontalalignment="center",
        verticalalignment="center",
        rotation="horizontal",
        transform=ax.transAxes,
        color=fontcolor,
        fontsize=fontsize,
    )
    rect = patches.Rectangle(
        (0, 1.0),
        1.0,
        0.3,
        facecolor=facecolor,
        alpha=1.0,
        transform=ax.transAxes,
        clip_on=False,
    )
    ax.add_patch(rect)


def plot_grid(
    images,
    slices=None,
    axes=2,
    # general figure arguments
    figsize=1.0,
    rpad=0,
    cpad=0,
    vmin=None,
    vmax=None,
    colorbar=True,
    cmap="Greys_r",
    # title arguments
    title=None,
    tfontsize=20,
    title_dx=0,
    title_dy=0,
    # row arguments
    rlabels=None,
    rfontsize=14,
    rfontcolor="white",
    rfacecolor="black",
    # column arguments
    clabels=None,
    cfontsize=14,
    cfontcolor="white",
    cfacecolor="black",
    # save arguments
    filename=None,
    dpi=400,
    transparent=True,
    # other args
    **kwargs
):
    """
    Plot a collection of images in an arbitrarily-defined grid

    Matplotlib named colors: https://matplotlib.org/stable/gallery/color/named_colors.html

    Arguments
    ---------
    images : list of ANTsImage types
        image(s) to plot, arranged as a nested list (or numpy object array)
        whose shape defines the grid, i.e. ``images[row][col]``.
        A flat list is treated as a single row of the grid.
        2D images are plotted as-is; 3D images are sliced using `slices`
        and `axes`.

    slices : None, integer, or (nested) list/array of integers
        slice indices to plot (ignored for 2D images).
        if None, a slice near the middle of the chosen axis is used.
        if one integer, this slice index will be used for all images.
        if multiple integers, they should be arranged in a list the same
        shape as the `images` grid.

    axes : integer or (nested) list/array of integers
        axis or axes along which to plot image slices
        if one integer, this axis will be used for all images
        if multiple integers, they should be arranged in a list the same
        shape as the `images` grid

    figsize : float
        scale factor; the figure is (ncol + 1) * 2.5 * figsize inches wide
        and (nrow + 1) * 2.5 * figsize inches tall.

    rpad, cpad : float
        padding, as a fraction of the figure. If both are > 0, max(rpad, cpad)
        is used as the horizontal spacing between grid cells. Otherwise cpad
        is added to the bottom margin and rpad to the left margin.

    vmin, vmax : scalar or list of scalars
        intensity limits passed to ``imshow``; either one value for the whole
        figure or one value per row.

    colorbar : boolean
        if True, draw a colorbar to the right of every panel.

    cmap : string, matplotlib colormap, or list of these
        colormap for the whole figure or one per row.

    title : string
        figure title, centred at x = 0.5 + title_dx and
        y = 0.9 (0.95 when column labels are given) + title_dy,
        in figure coordinates.

    tfontsize : integer
        title font size.

    rlabels : list of strings
        one label per row, drawn to the left of the first column.
        rfontsize / rfontcolor / rfacecolor control its appearance.

    clabels : list of strings
        one label per column, drawn above the first row.
        cfontsize / cfontcolor / cfacecolor control its appearance.

    filename : string
        if given, the figure is saved there (``~`` is expanded) with `dpi` and
        `transparent` and then closed; otherwise ``plt.show()`` is called.

    kwargs : accepted for API compatibility; currently unused.

    Raises
    ------
    ValueError
        if `images` is not a list/array, if `slices` has an unsupported type or
        a shape different from `images`, or if a per-row `vmin`/`vmax`/`cmap`
        list has fewer entries than there are rows.

    Example
    -------
    >>> import ants
    >>> import numpy as np
    >>> mni1 = ants.image_read(ants.get_data('mni'))
    >>> mni2 = mni1.smooth_image(1.)
    >>> mni3 = mni1.smooth_image(2.)
    >>> mni4 = mni1.smooth_image(3.)
    >>> images = np.asarray([[mni1, mni2],
    ...                      [mni3, mni4]])
    >>> slices = np.asarray([[100, 100],
    ...                      [100, 100]])
    >>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
    >>> images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
    ...                        [mni3.slice_image(2,100), mni4.slice_image(2,100)]])
    >>> ants.plot_grid(images=images2d, title='2x2 Grid Pre-Sliced')
    >>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
    >>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')

    >>> # Padding between rows and/or columns
    >>> ants.plot_grid(images, slices, cpad=0.02, title='Col Padding')
    >>> ants.plot_grid(images, slices, rpad=0.02, title='Row Padding')
    >>> ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding')

    >>> # Adding plain row and/or column labels
    >>> ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2'])
    >>> ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2'])
    >>> ants.plot_grid(images, slices, title='Row and Col Labels',
    ...                rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2'])

    >>> # Making a publication-quality image
    >>> images = np.asarray([[mni1, mni2, mni2],
    ...                      [mni3, mni4, mni4]])
    >>> slices = np.asarray([[100, 100, 100],
    ...                      [100, 100, 100]])
    >>> axes = np.asarray([[0, 1, 2],
    ...                    [0, 1, 2]])
    >>> ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy',
    ...                tfontsize=20, title_dy=0.03, title_dx=-0.04,
    ...                rlabels=['Row 1', 'Row 2'],
    ...                clabels=['Col 1', 'Col 2', 'Col 3'],
    ...                rfontsize=16, cfontsize=16)
    """
    # ---- normalise `images` into a list of rows --------------------------
    if isinstance(images, np.ndarray):
        images = images.tolist()
    if not isinstance(images, list):
        raise ValueError("images argument must be of type list")
    if not isinstance(images[0], list):
        images = [images]
    nrow = len(images)
    ncol = len(images[0])

    # ---- normalise `slices`: one shared index, or a grid matching `images` --
    if isinstance(slices, np.ndarray):
        slices = slices.tolist()
    if isinstance(slices, np.integer):
        slices = int(slices)
    if slices is None or isinstance(slices, int):
        one_slice = True
    elif isinstance(slices, (list, tuple)):
        one_slice = False
        if isinstance(slices[0], (list, tuple)):
            slices = [list(row) for row in slices]
        else:
            slices = [list(slices)]
        nslicerow = len(slices)
        nslicecol = len(slices[0])
        if (nrow != nslicerow) or (ncol != nslicecol):
            raise ValueError(
                "`images` arg shape (%i,%i) must equal `slices` arg shape (%i,%i)!"
                % (nrow, ncol, nslicerow, nslicecol)
            )
    else:
        raise ValueError(
            "slices argument must be None, an integer, or a (nested) list of integers"
        )

    if isinstance(axes, np.integer):
        axes = int(axes)

    if rlabels is None:
        rlabels = [None] * nrow
    if clabels is None:
        clabels = [None] * ncol

    # Validate per-row settings before any figure is created.
    vmins = _broadcast_per_row(vmin, nrow, "vmin")
    vmaxs = _broadcast_per_row(vmax, nrow, "vmax")
    cmaps = _broadcast_per_row(cmap, nrow, "cmap")

    # ---- figure, title and grid layout -----------------------------------
    fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize))

    if title is not None:
        basex = 0.5
        basey = 0.9 if clabels[0] is None else 0.95
        fig.suptitle(title, fontsize=tfontsize, x=basex + title_dx, y=basey + title_dy)

    # NOTE(review): when only one of rpad/cpad is set it enlarges a figure
    # margin (cpad -> bottom, rpad -> left) rather than the spacing between
    # rows/columns; kept as-is for backward compatibility. Confirm intent.
    if (cpad > 0) and (rpad > 0):
        bothgridpad = max(cpad, rpad)
        cpad = 0
        rpad = 0
    else:
        bothgridpad = 0.0

    gs = gridspec.GridSpec(
        nrow,
        ncol,
        wspace=bothgridpad,
        hspace=0.0,
        top=1.0 - 0.5 / (nrow + 1),
        bottom=0.5 / (nrow + 1) + cpad,
        left=0.5 / (ncol + 1) + rpad,
        right=1 - 0.5 / (ncol + 1),
    )

    # ---- draw every panel ------------------------------------------------
    for rowidx, rvmin, rvmax, rcmap in zip(range(nrow), vmins, vmaxs, cmaps):
        for colidx in range(ncol):
            ax = plt.subplot(gs[rowidx, colidx])

            if colidx == 0 and rlabels[rowidx] is not None:
                _draw_row_label(
                    ax, rlabels[rowidx], rowidx == 0, rfontsize, rfontcolor, rfacecolor
                )

            if rowidx == 0 and clabels[colidx] is not None:
                _draw_col_label(ax, clabels[colidx], cfontsize, cfontcolor, cfacecolor)

            tmpimg = images[rowidx][colidx]
            tmpaxis = axes if isinstance(axes, int) else axes[rowidx][colidx]

            if tmpimg.dimension == 2:
                tmpslice = tmpimg.numpy()
            else:
                sliceidx = slices if one_slice else slices[rowidx][colidx]
                if sliceidx is None:
                    sliceidx = _default_slice_index(tmpimg.shape[tmpaxis])
                tmpslice = _slice_image(tmpimg, tmpaxis, sliceidx)
            tmpslice = _reorient_slice(tmpslice)

            im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax)
            ax.axis("off")

            # inset_axes is used (rather than make_axes_locatable) because the
            # latter does not allow scaling the colorbar's height.
            if colorbar:
                axins = inset_axes(
                    ax,
                    width="5%",  # 5% of the panel width
                    height="90%",  # 90% of the panel height
                    loc="center left",
                    bbox_to_anchor=(1.03, 0.0, 1, 1),  # just right of the panel
                    bbox_transform=ax.transAxes,
                    borderpad=0,
                )
                fig.colorbar(im, cax=axins, orientation="vertical")

    # ---- output ----------------------------------------------------------
    if filename is not None:
        filename = os.path.expanduser(filename)
        fig.savefig(filename, dpi=dpi, transparent=transparent, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()