Switch to unified view

a b/ants/plotting/plot_grid.py
1
"""
2
Functions for plotting ants images
3
"""
4
5
6
__all__ = [
7
    "plot_grid"
8
]
9
10
import fnmatch
11
import math
12
import os
13
import warnings
14
15
from matplotlib import gridspec
16
import matplotlib.pyplot as plt
17
import matplotlib.patheffects as path_effects
18
import matplotlib.lines as mlines
19
import matplotlib.patches as patches
20
import matplotlib.mlab as mlab
21
import matplotlib.animation as animation
22
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
23
24
25
import numpy as np
26
27
28
def plot_grid(
29
    images,
30
    slices=None,
31
    axes=2,
32
    # general figure arguments
33
    figsize=1.0,
34
    rpad=0,
35
    cpad=0,
36
    vmin=None,
37
    vmax=None,
38
    colorbar=True,
39
    cmap="Greys_r",
40
    # title arguments
41
    title=None,
42
    tfontsize=20,
43
    title_dx=0,
44
    title_dy=0,
45
    # row arguments
46
    rlabels=None,
47
    rfontsize=14,
48
    rfontcolor="white",
49
    rfacecolor="black",
50
    # column arguments
51
    clabels=None,
52
    cfontsize=14,
53
    cfontcolor="white",
54
    cfacecolor="black",
55
    # save arguments
56
    filename=None,
57
    dpi=400,
58
    transparent=True,
59
    # other args
60
    **kwargs
61
):
62
    """
63
    Plot a collection of images in an arbitrarily-defined grid
64
65
    Matplotlib named colors: https://matplotlib.org/examples/color/named_colors.html
66
67
    Arguments
68
    ---------
69
    images : list of ANTsImage types
70
        image(s) to plot.
71
        if one image, this image will be used for all grid locations.
72
        if multiple images, they should be arrange in a list the same
73
        shape as the `gridsize` argument.
74
75
    slices : integer or list of integers
76
        slice indices to plot
77
        if one integer, this slice index will be used for all images
78
        if multiple integers, they should be arranged in a list the same
79
        shape as the `gridsize` argument
80
81
    axes : integer or list of integers
82
        axis or axes along which to plot image slices
83
        if one integer, this axis will be used for all images
84
        if multiple integers, they should be arranged in a list the same
85
        shape as the `gridsize` argument
86
87
    Example
88
    -------
89
    >>> import ants
90
    >>> import numpy as np
91
    >>> mni1 = ants.image_read(ants.get_data('mni'))
92
    >>> mni2 = mni1.smooth_image(1.)
93
    >>> mni3 = mni1.smooth_image(2.)
94
    >>> mni4 = mni1.smooth_image(3.)
95
    >>> images = np.asarray([[mni1, mni2],
96
    ...                      [mni3, mni4]])
97
    >>> slices = np.asarray([[100, 100],
98
    ...                      [100, 100]])
99
    >>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
100
    >>> images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
101
    ...                        [mni3.slice_image(2,100), mni4.slice_image(2,100)]])
102
    >>> ants.plot_grid(images=images2d, title='2x2 Grid Pre-Sliced')
103
    >>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
104
    >>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')
105
106
    >>> # Padding between rows and/or columns
107
    >>> ants.plot_grid(images, slices, cpad=0.02, title='Col Padding')
108
    >>> ants.plot_grid(images, slices, rpad=0.02, title='Row Padding')
109
    >>> ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding')
110
111
    >>> # Adding plain row and/or column labels
112
    >>> ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2'])
113
    >>> ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2'])
114
    >>> ants.plot_grid(images, slices, title='Row and Col Labels',
115
                       rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2'])
116
117
    >>> # Making a publication-quality image
118
    >>> images = np.asarray([[mni1, mni2, mni2],
119
    ...                      [mni3, mni4, mni4]])
120
    >>> slices = np.asarray([[100, 100, 100],
121
    ...                      [100, 100, 100]])
122
    >>> axes = np.asarray([[0, 1, 2],
123
                           [0, 1, 2]])
124
    >>> ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy',
125
                       tfontsize=20, title_dy=0.03, title_dx=-0.04,
126
                       rlabels=['Row 1', 'Row 2'],
127
                       clabels=['Col 1', 'Col 2', 'Col 3'],
128
                       rfontsize=16, cfontsize=16)
129
    """
130
131
    def mirror_matrix(x):
132
        return x[::-1, :]
133
134
    def rotate270_matrix(x):
135
        return mirror_matrix(x.T)
136
137
    def rotate180_matrix(x):
138
        return x[::-1, ::-1]
139
140
    def rotate90_matrix(x):
141
        return mirror_matrix(x).T
142
143
    def flip_matrix(x):
144
        return mirror_matrix(rotate180_matrix(x))
145
146
    def reorient_slice(x, axis):
147
        if axis != 1:
148
            x = rotate90_matrix(x)
149
        if axis == 1:
150
            x = rotate90_matrix(x)
151
        x = mirror_matrix(x)
152
        return x
153
154
    def slice_image(img, axis, idx):
155
        if axis == 0:
156
            return img[idx, :, :].numpy()
157
        elif axis == 1:
158
            return img[:, idx, :].numpy()
159
        elif axis == 2:
160
            return img[:, :, idx].numpy()
161
        elif axis == -1:
162
            return img[:, :, idx].numpy()
163
        elif axis == -2:
164
            return img[:, idx, :].numpy()
165
        elif axis == -3:
166
            return img[idx, :, :].numpy()
167
        else:
168
            raise ValueError("axis %i not valid" % axis)
169
170
    if isinstance(images, np.ndarray):
171
        images = images.tolist()
172
    if not isinstance(images, list):
173
        raise ValueError("images argument must be of type list")
174
    if not isinstance(images[0], list):
175
        images = [images]
176
177
    if slices is None:
178
        one_slice = True
179
    if isinstance(slices, int):
180
        one_slice = True
181
    if isinstance(slices, np.ndarray):
182
        slices = slices.tolist()
183
    if isinstance(slices, list):
184
        one_slice = False
185
        if not isinstance(slices[0], list):
186
            slices = [slices]
187
        nslicerow = len(slices)
188
        nslicecol = len(slices[0])
189
190
    nrow = len(images)
191
    ncol = len(images[0])
192
193
    if rlabels is None:
194
        rlabels = [None] * nrow
195
    if clabels is None:
196
        clabels = [None] * ncol
197
198
    if not one_slice:
199
        if (nrow != nslicerow) or (ncol != nslicecol):
200
            raise ValueError(
201
                "`images` arg shape (%i,%i) must equal `slices` arg shape (%i,%i)!"
202
                % (nrow, ncol, nslicerow, nslicecol)
203
            )
204
205
    fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize))
206
207
    if title is not None:
208
        basex = 0.5
209
        basey = 0.9 if clabels[0] is None else 0.95
210
        fig.suptitle(title, fontsize=tfontsize, x=basex + title_dx, y=basey + title_dy)
211
212
    if (cpad > 0) and (rpad > 0):
213
        bothgridpad = max(cpad, rpad)
214
        cpad = 0
215
        rpad = 0
216
    else:
217
        bothgridpad = 0.0
218
219
    gs = gridspec.GridSpec(
220
        nrow,
221
        ncol,
222
        wspace=bothgridpad,
223
        hspace=0.0,
224
        top=1.0 - 0.5 / (nrow + 1),
225
        bottom=0.5 / (nrow + 1) + cpad,
226
        left=0.5 / (ncol + 1) + rpad,
227
        right=1 - 0.5 / (ncol + 1),
228
    )
229
230
    if isinstance(vmin, (int, float)):
231
        vmins = [vmin] * nrow
232
    elif vmin is None:
233
        vmins = [None] * nrow
234
    else:
235
        vmins = vmin
236
237
    if isinstance(vmax, (int, float)):
238
        vmaxs = [vmax] * nrow
239
    elif vmax is None:
240
        vmaxs = [None] * nrow
241
    else:
242
        vmaxs = vmax
243
244
    if isinstance(cmap, str):
245
        cmaps = [cmap] * nrow
246
    elif cmap is None:
247
        cmaps = [None] * nrow
248
    else:
249
        cmaps = cmap
250
251
    for rowidx, rvmin, rvmax, rcmap in zip(range(nrow), vmins, vmaxs, cmaps):
252
        for colidx in range(ncol):
253
            ax = plt.subplot(gs[rowidx, colidx])
254
255
            if colidx == 0:
256
                if rlabels[rowidx] is not None:
257
                    bottom, height = 0.25, 0.5
258
                    top = bottom + height
259
                    # add label text
260
                    ax.text(
261
                        -0.07,
262
                        0.5 * (bottom + top),
263
                        rlabels[rowidx],
264
                        horizontalalignment="right",
265
                        verticalalignment="center",
266
                        rotation="vertical",
267
                        transform=ax.transAxes,
268
                        color=rfontcolor,
269
                        fontsize=rfontsize,
270
                    )
271
272
                    # add label background
273
                    extra = 0.3 if rowidx == 0 else 0.0
274
275
                    rect = patches.Rectangle(
276
                        (-0.3, 0),
277
                        0.3,
278
                        1.0 + extra,
279
                        facecolor=rfacecolor,
280
                        alpha=1.0,
281
                        transform=ax.transAxes,
282
                        clip_on=False,
283
                    )
284
                    ax.add_patch(rect)
285
286
            if rowidx == 0:
287
                if clabels[colidx] is not None:
288
                    bottom, height = 0.25, 0.5
289
                    left, width = 0.25, 0.5
290
                    right = left + width
291
                    top = bottom + height
292
                    ax.text(
293
                        0.5 * (left + right),
294
                        0.09 + top + bottom,
295
                        clabels[colidx],
296
                        horizontalalignment="center",
297
                        verticalalignment="center",
298
                        rotation="horizontal",
299
                        transform=ax.transAxes,
300
                        color=cfontcolor,
301
                        fontsize=cfontsize,
302
                    )
303
304
                    # add label background
305
                    rect = patches.Rectangle(
306
                        (0, 1.0),
307
                        1.0,
308
                        0.3,
309
                        facecolor=cfacecolor,
310
                        alpha=1.0,
311
                        transform=ax.transAxes,
312
                        clip_on=False,
313
                    )
314
                    ax.add_patch(rect)
315
316
            tmpimg = images[rowidx][colidx]
317
            if isinstance(axes, int):
318
                tmpaxis = axes
319
            else:
320
                tmpaxis = axes[rowidx][colidx]
321
            
322
            if tmpimg.dimension == 2:
323
                tmpslice = tmpimg.numpy()
324
                tmpslice = reorient_slice(tmpslice, tmpaxis)
325
            else:
326
                sliceidx = slices[rowidx][colidx] if not one_slice else slices
327
                if sliceidx is None:
328
                    sliceidx = math.ceil(tmpimg.shape[tmpaxis] / 2)
329
                tmpslice = slice_image(tmpimg, tmpaxis, sliceidx)
330
                tmpslice = reorient_slice(tmpslice, tmpaxis)
331
            
332
            im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax)
333
            ax.axis("off")
334
335
        # A colorbar solution with make_axes_locatable will not allow y-scaling of the colorbar.
336
        # from mpl_toolkits.axes_grid1 import make_axes_locatable
337
        # divider = make_axes_locatable(ax)
338
        # cax = divider.append_axes('right', size='5%', pad=0.05)
339
        if colorbar:
340
            axins = inset_axes(ax,
341
                               width="5%",  # width = 5% of parent_bbox width
342
                               height="90%",  # height : 50%
343
                               loc='center left',
344
                               bbox_to_anchor=(1.03, 0., 1, 1),
345
                               bbox_transform=ax.transAxes,
346
                               borderpad=0,
347
            )
348
            fig.colorbar(im, cax=axins, orientation='vertical')
349
350
    if filename is not None:
351
        filename = os.path.expanduser(filename)
352
        plt.savefig(filename, dpi=dpi, transparent=transparent, bbox_inches="tight")
353
        plt.close(fig)
354
    else:
355
        plt.show()