"""Functions to plot on circle as for connectivity."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from functools import partial
from itertools import cycle
from types import SimpleNamespace
import numpy as np
from ..utils import _validate_type
from .utils import _get_cmap, plt_show
def circular_layout(
node_names,
node_order,
start_pos=90,
start_between=True,
group_boundaries=None,
group_sep=10,
):
"""Create layout arranging nodes on a circle.
Parameters
----------
node_names : list of str
Node names.
node_order : list of str
List with node names defining the order in which the nodes are
arranged. Must have the elements as node_names but the order can be
different. The nodes are arranged clockwise starting at "start_pos"
degrees.
start_pos : float
Angle in degrees that defines where the first node is plotted.
start_between : bool
If True, the layout starts with the position between the nodes. This is
the same as adding "180. / len(node_names)" to start_pos.
group_boundaries : None | array-like
List of of boundaries between groups at which point a "group_sep" will
be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
group_sep : float
Group separation angle in degrees. See "group_boundaries".
Returns
-------
node_angles : array, shape=(n_node_names,)
Node angles in degrees.
"""
n_nodes = len(node_names)
if len(node_order) != n_nodes:
raise ValueError("node_order has to be the same length as node_names")
if group_boundaries is not None:
boundaries = np.array(group_boundaries, dtype=np.int64)
if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
raise ValueError('"group_boundaries" must have non-decreasing values.')
n_group_sep = len(group_boundaries)
else:
n_group_sep = 0
boundaries = None
# convert it to a list with indices
node_order = [node_order.index(name) for name in node_names]
node_order = np.array(node_order)
if len(np.unique(node_order)) != n_nodes:
raise ValueError("node_order has repeated entries")
node_sep = (360.0 - n_group_sep * group_sep) / n_nodes
if start_between:
start_pos += node_sep / 2
if boundaries is not None and boundaries[0] == 0:
# special case when a group separator is at the start
start_pos += group_sep / 2
boundaries = boundaries[1:] if n_group_sep > 1 else None
node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
node_angles[0] = start_pos
if boundaries is not None:
node_angles[boundaries] += group_sep
node_angles = np.cumsum(node_angles)[node_order]
return node_angles
def _plot_connectivity_circle_onpick(
event,
fig=None,
ax=None,
indices=None,
n_nodes=0,
node_angles=None,
ylim=(9, 10),
):
"""Isolate connections around a single node when user left clicks a node.
On right click, resets all connections.
"""
if event.inaxes != ax:
return
if event.button == 1: # left click
# click must be near node radius
if not ylim[0] <= event.ydata <= ylim[1]:
return
# all angles in range [0, 2*pi]
node_angles = node_angles % (np.pi * 2)
node = np.argmin(np.abs(event.xdata - node_angles))
patches = event.inaxes.patches
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
patches[ii].set_visible(node in [x, y])
fig.canvas.draw()
elif event.button == 3: # right click
patches = event.inaxes.patches
for ii in range(np.size(indices, axis=1)):
patches[ii].set_visible(True)
fig.canvas.draw()
def _plot_connectivity_circle(
con,
node_names,
indices=None,
n_lines=None,
node_angles=None,
node_width=None,
node_height=None,
node_colors=None,
facecolor="black",
textcolor="white",
node_edgecolor="black",
linewidth=1.5,
colormap="hot",
vmin=None,
vmax=None,
colorbar=True,
title=None,
colorbar_size=None,
colorbar_pos=None,
fontsize_title=12,
fontsize_names=8,
fontsize_colorbar=8,
padding=6.0,
ax=None,
interactive=True,
node_linewidth=2.0,
show=True,
):
import matplotlib.patches as m_patches
import matplotlib.path as m_path
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes
_validate_type(ax, (None, PolarAxes))
n_nodes = len(node_names)
if node_angles is not None:
if len(node_angles) != n_nodes:
raise ValueError("node_angles has to be the same length as node_names")
# convert it to radians
node_angles = node_angles * np.pi / 180
else:
# uniform layout on unit circle
node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
if node_width is None:
# widths correspond to the minimum angle between two nodes
dist_mat = node_angles[None, :] - node_angles[:, None]
dist_mat[np.diag_indices(n_nodes)] = 1e9
node_width = np.min(np.abs(dist_mat))
else:
node_width = node_width * np.pi / 180
if node_height is None:
node_height = 1.0
if node_colors is not None:
if len(node_colors) < n_nodes:
node_colors = cycle(node_colors)
else:
# assign colors using colormap
try:
spectral = plt.cm.spectral
except AttributeError:
spectral = plt.cm.Spectral
node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]
# handle 1D and 2D connectivity information
if con.ndim == 1:
if indices is None:
raise ValueError("indices has to be provided if con.ndim == 1")
elif con.ndim == 2:
if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
raise ValueError("con has to be 1D or a square matrix")
# we use the lower-triangular part
indices = np.tril_indices(n_nodes, -1)
con = con[indices]
else:
raise ValueError("con has to be 1D or a square matrix")
# get the colormap
colormap = _get_cmap(colormap)
# Use a polar axes
if ax is None:
fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
ax = fig.add_subplot(polar=True)
else:
fig = ax.figure
ax.set_facecolor(facecolor)
# No ticks, we'll put our own
ax.set_xticks([])
ax.set_yticks([])
# Set y axes limit, add additional space if requested
ax.set_ylim(0, 10 + padding)
# Remove the black axes border which may obscure the labels
ax.spines["polar"].set_visible(False)
# Draw lines between connected nodes, only draw the strongest connections
if n_lines is not None and len(con) > n_lines:
con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
else:
con_thresh = 0.0
# get the connections which we are drawing and sort by connection strength
# this will allow us to draw the strongest connections first
con_abs = np.abs(con)
con_draw_idx = np.where(con_abs >= con_thresh)[0]
con = con[con_draw_idx]
con_abs = con_abs[con_draw_idx]
indices = [ind[con_draw_idx] for ind in indices]
# now sort them
sort_idx = np.argsort(con_abs)
del con_abs
con = con[sort_idx]
indices = [ind[sort_idx] for ind in indices]
# Get vmin vmax for color scaling
if vmin is None:
vmin = np.min(con[np.abs(con) >= con_thresh])
if vmax is None:
vmax = np.max(con)
vrange = vmax - vmin
# We want to add some "noise" to the start and end position of the
# edges: We modulate the noise with the number of connections of the
# node and the connection strength, such that the strongest connections
# are closer to the node center
nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
for i, j in zip(indices[0], indices[1]):
nodes_n_con[i] += 1
nodes_n_con[j] += 1
# initialize random number generator so plot is reproducible
rng = np.random.mtrand.RandomState(0)
n_con = len(indices[0])
noise_max = 0.25 * node_width
start_noise = rng.uniform(-noise_max, noise_max, n_con)
end_noise = rng.uniform(-noise_max, noise_max, n_con)
nodes_n_con_seen = np.zeros_like(nodes_n_con)
for i, (start, end) in enumerate(zip(indices[0], indices[1])):
nodes_n_con_seen[start] += 1
nodes_n_con_seen[end] += 1
start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
nodes_n_con[start]
)
end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
nodes_n_con[end]
)
# scale connectivity for colormap (vmin<=>0, vmax<=>1)
con_val_scaled = (con - vmin) / vrange
# Finally, we draw the connections
for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
# Start point
t0, r0 = node_angles[i], 10
# End point
t1, r1 = node_angles[j], 10
# Some noise in start and end point
t0 += start_noise[pos]
t1 += end_noise[pos]
verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
codes = [
m_path.Path.MOVETO,
m_path.Path.CURVE4,
m_path.Path.CURVE4,
m_path.Path.LINETO,
]
path = m_path.Path(verts, codes)
color = colormap(con_val_scaled[pos])
# Actual line
patch = m_patches.PathPatch(
path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
)
ax.add_patch(patch)
# Draw ring with colored nodes
height = np.ones(n_nodes) * node_height
bars = ax.bar(
node_angles,
height,
width=node_width,
bottom=9,
edgecolor=node_edgecolor,
lw=node_linewidth,
facecolor=".9",
align="center",
)
for bar, color in zip(bars, node_colors):
bar.set_facecolor(color)
# Draw node labels
angles_deg = 180 * node_angles / np.pi
for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
if angle_deg >= 270:
ha = "left"
else:
# Flip the label, so text is always upright
angle_deg += 180
ha = "right"
ax.text(
angle_rad,
9.4 + node_height,
name,
size=fontsize_names,
rotation=angle_deg,
rotation_mode="anchor",
horizontalalignment=ha,
verticalalignment="center",
color=textcolor,
)
if title is not None:
ax.set_title(title, color=textcolor, fontsize=fontsize_title)
if colorbar:
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
sm.set_array(np.linspace(vmin, vmax))
colorbar_kwargs = dict()
if colorbar_size is not None:
colorbar_kwargs.update(shrink=colorbar_size)
if colorbar_pos is not None:
colorbar_kwargs.update(anchor=colorbar_pos)
cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
cb.ax.tick_params(labelsize=fontsize_colorbar)
plt.setp(cb_yticks, color=textcolor)
fig.mne = SimpleNamespace(colorbar=cb)
# Add callback for interaction
if interactive:
callback = partial(
_plot_connectivity_circle_onpick,
fig=fig,
ax=ax,
indices=indices,
n_nodes=n_nodes,
node_angles=node_angles,
)
fig.canvas.mpl_connect("button_press_event", callback)
plt_show(show)
return fig, ax
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
"""Plot labels for each channel in a circle plot.
.. note:: This primarily makes sense for sEEG channels where each
channel can be assigned an anatomical label as the electrode
passes through various brain areas.
Parameters
----------
labels : dict
Lists of labels (values) associated with each channel (keys).
colors : dict
The color (value) for each label (key).
picks : list | tuple
The channels to consider.
**kwargs : kwargs
Keyword arguments for
:func:`mne_connectivity.viz.plot_connectivity_circle`.
Returns
-------
fig : instance of matplotlib.figure.Figure
The figure handle.
axes : instance of matplotlib.projections.polar.PolarAxes
The subplot handle.
"""
from matplotlib.colors import LinearSegmentedColormap
_validate_type(labels, dict, "labels")
_validate_type(colors, (dict, None), "colors")
_validate_type(picks, (list, tuple, None), "picks")
if picks is not None:
labels = {k: v for k, v in labels.items() if k in picks}
ch_names = list(labels.keys())
all_labels = list(set([label for val in labels.values() for label in val]))
n_labels = len(all_labels)
if colors is not None:
for label in all_labels:
if label not in colors:
raise ValueError(f"No color provided for {label} in `colors`")
# update all_labels, there may be unconnected labels in colors
all_labels = list(colors.keys())
n_labels = len(all_labels)
# make colormap
label_colors = [colors[label] for label in all_labels]
node_colors = ["black"] * len(ch_names) + label_colors
label_cmap = LinearSegmentedColormap.from_list(
"label_cmap", label_colors, N=len(label_colors)
)
else:
node_colors = None
node_names = ch_names + all_labels
con = np.zeros((len(node_names), len(node_names))) * np.nan
for idx, ch_name in enumerate(ch_names):
for label in labels[ch_name]:
node_idx = node_names.index(label)
label_color = all_labels.index(label) / n_labels
con[idx, node_idx] = con[node_idx, idx] = label_color # symmetric
# plot
node_order = ch_names + all_labels[::-1]
node_angles = circular_layout(
node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
)
# provide defaults but don't overwrite
if "node_angles" not in kwargs:
kwargs.update(node_angles=node_angles)
if "colorbar" not in kwargs:
kwargs.update(colorbar=False)
if "node_colors" not in kwargs:
kwargs.update(node_colors=node_colors)
if "vmin" not in kwargs:
kwargs.update(vmin=0)
if "vmax" not in kwargs:
kwargs.update(vmax=1)
if "colormap" not in kwargs:
kwargs.update(colormap=label_cmap)
return _plot_connectivity_circle(con, node_names, **kwargs)