"""Dipole viz specific functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os.path as op
import numpy as np
from scipy.spatial import ConvexHull
from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
from ..surface import read_surface
from ..transforms import _get_trans, apply_trans, invert_transform
from ..utils import _check_option, _validate_type, get_subjects_dir
from .utils import _validate_if_list_of_axes, plt_show
def _check_concat_dipoles(dipole):
    """Return a single Dipole, concatenating the input if it is not one already.

    Parameters
    ----------
    dipole : Dipole | list of Dipole
        A Dipole instance is returned unchanged; anything else is passed to
        ``_concatenate_dipoles``.

    Returns
    -------
    dipole : Dipole
        The (possibly concatenated) dipole object.
    """
    from ..dipole import Dipole, _concatenate_dipoles

    if isinstance(dipole, Dipole):
        return dipole
    return _concatenate_dipoles(dipole)
def _plot_dipole_mri_outlines(
    dipoles,
    *,
    subject,
    trans,
    ax,
    subjects_dir,
    color,
    scale,
    coord_frame,
    show,
    block,
    head_source,
    title,
    surf,
    width,
):
    """Plot dipoles over axial, coronal and sagittal surface outlines.

    Parameters
    ----------
    dipoles : Dipole | list of Dipole
        The dipole(s) to draw; a list is concatenated into one Dipole.
    subject : str
        The FreeSurfer subject name.
    trans : object
        Head/MRI transform, anything accepted by ``_get_trans``.
    ax : None | sequence of 3 matplotlib Axes
        Axes for the Axial, Coronal and Sagittal views (in that order).
        When None, a new 1x3 figure is created.
    subjects_dir : None | path-like
        Subjects directory; resolved with ``get_subjects_dir`` before use.
    color : None | color
        Dipole color; defaults to ``"r"``.
    scale : None | float
        Scale factor for the quiver arrows; defaults to 0.03.
    coord_frame : "head" | "mri" | "mri_rotated"
        Coordinate frame to draw in.
    show : bool
        Passed to ``plt_show``.
    block : bool
        Passed to ``plt_show``.
    head_source : object
        Passed to ``_get_head_surface`` to select the head surface.
    title : None | str
        Optional figure suptitle.
    surf : None | "white" | "pial"
        Cortical surface to outline in addition to the head surface.
    width : None | float
        Quiver shaft width; defaults to 0.015.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that owns ``ax[0]``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection, PatchCollection
    from matplotlib.patches import Circle

    extra = 'when mode is "outlines"'
    trans = _get_trans(trans, fro="head", to="mri")[0]
    _check_option(
        "coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
    )
    _validate_type(surf, (str, None), "surf")
    _check_option("surf", surf, ("white", "pial", None))
    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
    _validate_if_list_of_axes(ax, 3, name="ax")
    dipoles = _check_concat_dipoles(dipoles)
    color = "r" if color is None else color
    scale = 0.03 if scale is None else scale
    width = 0.015 if width is None else width
    fig = ax[0].figure

    # Resolve subjects_dir *before* it is used to build the cortical surface
    # paths below; previously the unresolved value was passed to op.join.
    subjects_dir = get_subjects_dir(subjects_dir)
    if subjects_dir is not None:
        subjects_dir = str(subjects_dir)

    surfs = dict()
    hemis = ("lh", "rh")
    if surf is not None:
        for hemi in hemis:
            surfs[hemi] = read_surface(
                op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
                return_dict=True,
            )[2]
            # presumably FreeSurfer mm -> m, to match the head surface units
            surfs[hemi]["rr"] /= 1000.0
    surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
    del head_source

    # mri_trans maps surface vertices (MRI frame) into the plotting frame;
    # head_trans does the same for the dipoles (head frame).
    mri_trans = head_trans = np.eye(4)
    if coord_frame in ("mri", "mri_rotated"):
        head_trans = trans["trans"]
        if coord_frame == "mri_rotated":
            rot = _estimate_talxfm_rigid(subject, subjects_dir)
            rot[:3, 3] = 0.0  # keep only the rotational part
            head_trans = rot @ head_trans
            mri_trans = rot @ mri_trans
    else:
        assert coord_frame == "head"
        mri_trans = invert_transform(trans)["trans"]
    # Scale by 1000 so that everything is plotted in mm (see axis labels).
    for s in surfs.values():
        s["rr"] = 1000 * apply_trans(mri_trans, s["rr"])
    del mri_trans

    # Dipole positions/orientations do not depend on the view: compute once.
    pos = 1000 * apply_trans(head_trans, dipoles.pos)
    ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)

    # Contour levels per surface and view. Axial/coronal levels are shared by
    # all surfaces; a sagittal level is set only for "rh" (when surf is given)
    # and for "head", so the left hemisphere is skipped in the sagittal view.
    if surf is not None:
        use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
    else:
        use_rr = surfs["head"]["rr"]
    views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
    # axial: 20% up the Z axis; coronal: 55% along the Y axis
    axial = float(np.percentile(use_rr[:, 2], 20.0))
    coronal = float(np.percentile(use_rr[:, 1], 55.0))
    levels = {key: dict(Axial=axial, Coronal=coronal) for key in hemis + ("head",)}
    if surf is not None:
        levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
    levels["head"]["Sagittal"] = 0.0

    axis_index = dict(X=0, Y=1, Z=2)
    for ax_, (name, coords) in zip(ax, views):
        idx = [axis_index[char] for char in coords]
        # the coordinate orthogonal to this view's plane (contoured at a level)
        miss = np.setdiff1d(np.arange(3), idx)[0]
        # axis limits come from the extent of the head surface
        lims = dict()
        for ii, char in enumerate(coords):
            lim = surfs["head"]["rr"][:, idx[ii]]
            lims[char] = np.array([lim.min(), lim.max()])
        ax_.quiver(
            pos[:, idx[0]],
            pos[:, idx[1]],
            scale * ori[:, idx[0]],
            scale * ori[:, idx[1]],
            color=color,
            pivot="middle",
            zorder=5,
            scale_units="xy",
            angles="xy",
            scale=1.0,
            width=width,
            minshaft=0.5,
            headwidth=2.5,
            headlength=2.5,
            headaxislength=2,
        )
        dip_coll = PatchCollection(
            [
                Circle((x, y), radius=scale * 1000 * width * 6)
                for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
            ],
            linewidths=0.0,
            facecolors=color,
            zorder=6,
        )
        for key, this_surf in surfs.items():
            try:
                level = levels[key][name]
            except KeyError:
                continue  # no level for this surface in this view
            if key != "head":
                # faint outline of the cortical surface's 2D convex hull
                rrs = this_surf["rr"][:, idx]
                tris = ConvexHull(rrs).simplices
                segments = LineCollection(
                    rrs[:, [0, 1]][tris],
                    linewidths=1,
                    linestyles="-",
                    colors="k",
                    zorder=3,
                    alpha=0.25,
                )
                ax_.add_collection(segments)
            ax_.tricontour(
                this_surf["rr"][:, idx[0]],
                this_surf["rr"][:, idx[1]],
                this_surf["tris"],
                this_surf["rr"][:, miss],
                levels=[level],
                colors="k",
                linewidths=1.0,
                linestyles=["-"],
                zorder=4,
                alpha=0.5,
            )
            # TODO: calling set_clip_on(False) on the contour collections
            # breaks the PatchCollection in MPL, so clipping is left enabled.
        ax_.add_collection(dip_coll)
        ax_.set(
            title=name,
            xlim=lims[coords[0]],
            ylim=lims[coords[1]],
            xlabel=coords[0] + " (mm)",
            ylabel=coords[1] + " (mm)",
        )
        for spine in ax_.spines.values():
            spine.set_visible(False)
        ax_.grid(True, ls=":", zorder=2)
        ax_.set_aspect("equal")
    if title is not None:
        fig.suptitle(title)
    plt_show(show, block=block)
    return fig
def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
    """Plot dipoles as spheres, optionally with arrows, in a 3D renderer.

    Parameters
    ----------
    dipoles : Dipole
        Object providing ``pos`` and ``ori`` arrays (presumably in head
        coordinates, since the transform used below starts from "head").
    coord_frame : "head" | "mri"
        Frame to draw in. For "mri", ``trans`` is applied to the dipoles.
    color : None | color
        Dipole color; defaults to ``"r"``.
    fig : object
        Passed to ``_get_renderer`` (existing figure or None).
    trans : object
        Head->MRI transform; only used when ``coord_frame == "mri"``.
    scale : None | float
        Sphere scale (arrows use ``3 * scale``); defaults to 0.005.
    mode : str
        Arrows are drawn only when ``mode == "arrow"``.

    Returns
    -------
    fig : object
        The renderer scene.
    """
    from .backends.renderer import _get_renderer

    _check_option("coord_frame", coord_frame, ("head", "mri"))
    color = "r" if color is None else color
    scale = 0.005 if scale is None else scale
    renderer = _get_renderer(fig=fig, size=(600, 600))
    pos = dipoles.pos
    ori = dipoles.ori
    if coord_frame != "head":
        trans = _get_trans(trans, fro="head", to=coord_frame)[0]
        pos = apply_trans(trans, pos)
        # Orientations are direction vectors: rotate them but do NOT add the
        # translation part of the transform (which would skew the arrows).
        ori = apply_trans(trans, ori, move=False)
    renderer.sphere(center=pos, color=color, scale=scale)
    if mode == "arrow":
        x, y, z = pos.T
        u, v, w = ori.T
        renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
    renderer.show()
    fig = renderer.scene()
    return fig