# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from functools import partial
import numpy as np
from scipy.spatial.distance import cdist
from ...utils import _check_option, _validate_type, fill_doc
def _check_stc(stc1, stc2):
"""Check that stcs are compatible."""
if stc1.data.shape != stc2.data.shape:
raise ValueError("Data in stcs must have the same size")
if np.all(stc1.times != stc2.times):
raise ValueError("Times of two stcs must match.")
def source_estimate_quantification(stc1, stc2, metric="rms"):
    """Calculate STC similarities across all sources and times.

    Parameters
    ----------
    stc1 : SourceEstimate
        First source estimate for comparison.
    stc2 : SourceEstimate
        Second source estimate for comparison.
    metric : str
        Metric to calculate, ``'rms'`` or ``'cosine'``.

    Returns
    -------
    score : float
        Calculated metric (a single value over all sources and times).

    Notes
    -----
    Metric calculation has multiple options:

    * rms: Root mean square of the difference between the stc data matrices.
    * cosine: Cosine distance, i.e. one minus the cosine similarity
      (normalized dot product) of all elements of the stc data matrices.
      It is 0 when the data are positively proportional, 1 when they are
      orthogonal and 2 when they point in opposite directions.

    .. versionadded:: 0.10.0
    """
    _check_option("metric", metric, ["rms", "cosine"])
    # This is checking that the data are having the same size meaning
    # no comparison between distributed and sparse can be done so far.
    _check_stc(stc1, stc2)
    data1, data2 = stc1.data, stc2.data
    if metric == "rms":
        # Root mean square of the element-wise difference.
        score = np.sqrt(np.mean((data1 - data2) ** 2))
    elif metric == "cosine":
        # Cosine *distance*: 1 - (normalized dot product of all elements).
        score = 1.0 - _cosine(data1, data2)
    return score
def _uniform_stc(stc1, stc2):
"""Uniform vertices of two stcs.
This function returns the stcs with the same vertices by
inserting zeros in data for missing vertices.
"""
if len(stc1.vertices) != len(stc2.vertices):
raise ValueError(
"Data in stcs must have the same number of vertices "
f"components. Got {len(stc1.vertices)} != {len(stc2.vertices)}."
)
idx_start1 = 0
idx_start2 = 0
stc1 = stc1.copy()
stc2 = stc2.copy()
all_data1 = []
all_data2 = []
for i, (vert1, vert2) in enumerate(zip(stc1.vertices, stc2.vertices)):
vert = np.union1d(vert1, vert2)
data1 = np.zeros([len(vert), stc1.data.shape[1]])
data2 = np.zeros([len(vert), stc2.data.shape[1]])
data1[np.searchsorted(vert, vert1)] = stc1.data[
idx_start1 : idx_start1 + len(vert1)
]
data2[np.searchsorted(vert, vert2)] = stc2.data[
idx_start2 : idx_start2 + len(vert2)
]
idx_start1 += len(vert1)
idx_start2 += len(vert2)
stc1.vertices[i] = vert
stc2.vertices[i] = vert
all_data1.append(data1)
all_data2.append(data2)
stc1._data = np.concatenate(all_data1, axis=0)
stc2._data = np.concatenate(all_data2, axis=0)
return stc1, stc2
def _apply(func, stc_true, stc_est, per_sample):
"""Apply metric to stcs.
Applies a metric to each pair of columns of stc_true and stc_est
if per_sample is True. Otherwise it applies it to stc_true and stc_est
directly.
"""
if per_sample:
metric = np.empty(stc_true.data.shape[1]) # one value per time point
for i in range(stc_true.data.shape[1]):
metric[i] = func(stc_true.data[:, i : i + 1], stc_est.data[:, i : i + 1])
else:
metric = func(stc_true.data, stc_est.data)
return metric
def _thresholding(stc_true, stc_est, threshold):
    """Zero out low-amplitude entries of the source estimates, in place.

    Parameters
    ----------
    stc_true : SourceEstimate | None
        Ground-truth source estimate; left alone when None.
    stc_est : SourceEstimate
        Estimated source estimate.
    threshold : float | str
        Absolute amplitude threshold, or, when given as a string such as
        ``"90%"``, a proportion of each stc's own maximum absolute value.

    Returns
    -------
    stc_true, stc_est : SourceEstimate
        The same objects, with every ``_data`` entry whose absolute value is
        less than or equal to the threshold set to ``0.0``.
    """
    is_relative = isinstance(threshold, str)
    threshold = _check_threshold(threshold)
    targets = [stc_est] if stc_true is None else [stc_true, stc_est]
    for stc in targets:
        magnitude = np.abs(stc._data)
        if is_relative:
            # Relative thresholds scale with this stc's own peak amplitude.
            cutoff = threshold * np.max(magnitude)
        else:
            cutoff = threshold
        stc._data[magnitude <= cutoff] = 0.0
    return stc_true, stc_est
def _cosine(x, y):
p = x.ravel()
q = y.ravel()
p_norm = np.linalg.norm(p)
q_norm = np.linalg.norm(q)
if p_norm * q_norm:
return (p.T @ q) / (p_norm * q_norm)
elif p_norm == q_norm:
return 1
else:
return 0
@fill_doc
def cosine_score(stc_true, stc_est, per_sample=True):
    """Compute cosine similarity between 2 source estimates.

    Both estimates are first brought onto a common set of vertices, with
    zeros inserted for vertices missing from either one. Then the cosine
    similarity of their data is computed.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    aligned_true, aligned_est = _uniform_stc(stc_true, stc_est)
    return _apply(_cosine, aligned_true, aligned_est, per_sample=per_sample)
def _check_threshold(threshold):
    """Accept a float or a string that ends with %.

    Parameters
    ----------
    threshold : float | str
        Either a proportion, or a percentage string such as ``"90%"``. The
        accepted types are validated by ``_validate_type``.

    Returns
    -------
    threshold : float
        The threshold as a proportion in ``[0, 1]``.

    Raises
    ------
    ValueError
        If a string does not end with ``"%"``, if the text before ``"%"`` is
        not a number, or if the resulting proportion lies outside ``[0, 1]``.
    """
    _validate_type(threshold, ("numeric", str), "threshold")
    if isinstance(threshold, str):
        if not threshold.endswith("%"):
            raise ValueError(
                f'Threshold if a string must end with "%". Got {threshold}.'
            )
        try:
            threshold = float(threshold[:-1]) / 100.0
        except ValueError as err:
            # A bare float() failure does not name the parameter or the
            # expected format, so report both.
            raise ValueError(
                'Threshold if a string must be a number followed by "%". '
                f"Got {threshold}."
            ) from err
    threshold = float(threshold)
    if not 0 <= threshold <= 1:
        raise ValueError(
            "Threshold proportion must be between 0 and 1 (inclusive), but "
            f"got {threshold}"
        )
    return threshold
def _abs_col_sum(x):
return np.abs(x).sum(axis=1)
def _dle(p, q, src, stc):
    """Aux function to compute dipole localization error.

    A row of ``p`` (true) or ``q`` (estimated) counts as an active source if
    any of its entries is non-zero. Source positions are taken from
    ``src[k]["rr"][stc.vertices[k]]`` for each source space ``k``, stacked in
    order. This assumes the rows of ``p`` and ``q`` follow that same vertex
    order; confirm at the call site.

    Returns
    -------
    error : float
        The average of two means: the distance from each active estimated
        source to its nearest active true source, and the distance from each
        active true source to its nearest active estimated source. Returns
        ``np.inf`` if either set has no active source.
    """
    true_active = np.flatnonzero(_abs_col_sum(p))
    est_active = np.flatnonzero(_abs_col_sum(q))
    positions = np.concatenate(
        [src[k]["rr"][stc.vertices[k]] for k in range(len(src))], axis=0
    )
    if len(true_active) == 0 or len(est_active) == 0:
        return np.inf
    dist = cdist(positions[true_active], positions[est_active])
    # Column-wise minima: nearest true source for every estimated source.
    nearest_true = dist.min(axis=0)
    # Row-wise minima: nearest estimated source for every true source.
    nearest_est = dist.min(axis=1)
    return (np.mean(nearest_true) + np.mean(nearest_est)) / 2.0
@fill_doc
def region_localization_error(stc_true, stc_est, src, threshold="90%", per_sample=True):
    r"""Compute region localization error (RLE) between 2 source estimates.

    .. math::

        RLE = \frac{1}{2Q}\sum_{k \in I} \min_{l \in \hat{I}}{||r_k - r_l||} + \frac{1}{2\hat{Q}}\sum_{l \in \hat{I}} \min_{k \in I}{||r_k - r_l||}

    where :math:`I` and :math:`\hat{I}` are the indices of the active sources
    in the original and the estimated source estimate, and :math:`Q` and
    :math:`\hat{Q}` are the corresponding numbers of active sources.
    :math:`r_k` is the position of the k-th source dipole, and
    :math:`||\cdot||` is the Euclidean norm in :math:`\mathbb{R}^3`.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to both source estimates before computing the
        region localization error. A string threshold is a percentage and
        must end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    Papers :footcite:`MaksymenkoEtAl2017` and :footcite:`BeckerEtAl2017`
    use the term Dipole Localization Error (DLE) for the same formula. Paper
    :footcite:`YaoEtAl2005` uses the term Error Distance (ED) for the same
    formula. To unify the terminology, and to avoid confusion with other uses
    of the term DLE for a different metric :footcite:`MolinsEtAl2008`, we
    use the term Region Localization Error (RLE).

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """  # noqa: E501
    stc_true, stc_est = _thresholding(*_uniform_stc(stc_true, stc_est), threshold)
    # After uniformisation both stcs share vertices, so either gives positions.
    rle = partial(_dle, src=src, stc=stc_true)
    return _apply(rle, stc_true, stc_est, per_sample=per_sample)
def _roc_auc_score(p, q):
    """ROC AUC with non-zero entries of ``p`` as positives, ``|q|`` as scores."""
    from sklearn.metrics import roc_auc_score as _sk_roc_auc_score

    is_positive = np.abs(p) > 0
    scores = np.abs(q)
    return _sk_roc_auc_score(is_positive, scores)
@fill_doc
def roc_auc_score(stc_true, stc_est, per_sample=True):
    """Compute ROC AUC between 2 source estimates.

    ROC stands for receiver operating characteristic, and AUC is the area
    under the (ROC) curve.

    When computing this metric, stc_true must already be thresholded, since
    any non-zero value is considered a positive.

    The ROC-AUC metric is computed between the amplitudes of the source
    estimates, i.e. after taking absolute values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    metric = _apply(_roc_auc_score, stc_true, stc_est, per_sample=per_sample)
    return metric
def _f1_score(p, q):
    """F1 score between row-wise activity masks (any non-zero entry -> active)."""
    from sklearn.metrics import f1_score as _sk_f1_score

    true_active = _abs_col_sum(p) > 0
    est_active = _abs_col_sum(q) > 0
    return _sk_f1_score(true_active, est_active)
@fill_doc
def f1_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the F1 score, also known as balanced F-score or F-measure.

    The F1 score is the harmonic mean of precision and recall. It reaches
    its best value at 1 and its worst at 0, and precision and recall
    contribute equally to it.

    The formula for the F1 score is::

        F1 = 2 * (precision * recall) / (precision + recall)

    The threshold is applied first, to binarize the data: a source counts as
    active if any of its (thresholded) values is non-zero.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the f1 score. If a string the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_f1_score, stc_true, stc_est, per_sample=per_sample)
    return metric
def _precision_score(p, q):
    """Precision between row-wise activity masks (any non-zero entry -> active)."""
    from sklearn.metrics import precision_score as _sk_precision_score

    true_active = _abs_col_sum(p) > 0
    est_active = _abs_col_sum(q) > 0
    return _sk_precision_score(true_active, est_active)
@fill_doc
def precision_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the precision.

    The precision is the ratio ``tp / (tp + fp)``, where ``tp`` is the number
    of true positives and ``fp`` the number of false positives. Intuitively,
    it measures how well the estimate avoids marking inactive sources as
    active. The best value is 1 and the worst value is 0.

    The threshold is applied first, to binarize the data.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing the
        precision. A string threshold is a percentage and must end with the
        percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _thresholding(*_uniform_stc(stc_true, stc_est), threshold)
    return _apply(_precision_score, stc_true, stc_est, per_sample=per_sample)
def _recall_score(p, q):
    """Recall between row-wise activity masks (any non-zero entry -> active)."""
    from sklearn.metrics import recall_score as _sk_recall_score

    true_active = _abs_col_sum(p) > 0
    est_active = _abs_col_sum(q) > 0
    return _sk_recall_score(true_active, est_active)
@fill_doc
def recall_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the recall.

    The recall is the ratio ``tp / (tp + fn)``, where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. Intuitively, it
    measures how many of the truly active sources the estimate recovers. The
    best value is 1 and the worst value is 0.

    The threshold is applied first, to binarize the data.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing the
        recall. A string threshold is a percentage and must end with the
        percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _thresholding(*_uniform_stc(stc_true, stc_est), threshold)
    return _apply(_recall_score, stc_true, stc_est, per_sample=per_sample)
def _prepare_ppe_sd(stc_true, stc_est, src, threshold="50%"):
    """Collect the inputs shared by peak position error and spatial deviation.

    Parameters
    ----------
    stc_true : SourceEstimate
        Ground truth; must contain exactly one vertex. It is only read.
    stc_est : SourceEstimate
        Estimate; a thresholded copy is returned, and the input is not
        modified.
    src : SourceSpaces
        Source spaces providing the vertex positions (``"rr"``).
    threshold : float | str
        Threshold applied to the estimate only.

    Returns
    -------
    stc_est : SourceEstimate
        Thresholded copy of the estimate.
    r_true : ndarray, shape (1, 3)
        Position of the single true dipole.
    r_est : ndarray, shape (n_est_vertices, 3)
        Positions of the estimate's vertices, stacked in source-space order.

    Raises
    ------
    ValueError
        If ``stc_true`` does not contain exactly one dipole.
    """
    # Only the estimate is modified (thresholded in place below), so only it
    # needs a copy; copying stc_true would duplicate its data for nothing.
    stc_est = stc_est.copy()
    n_dipoles = 0
    for i, v in enumerate(stc_true.vertices):
        if len(v):
            n_dipoles += len(v)
            r_true = src[i]["rr"][v]
    if n_dipoles != 1:
        raise ValueError(f"True source must contain only one dipole, got {n_dipoles}.")
    _, stc_est = _thresholding(None, stc_est, threshold)
    # Gather per-space positions and concatenate once, instead of re-stacking
    # inside the loop. The leading empty block keeps the (0, 3) float shape
    # when the estimate has no vertices.
    parts = [np.empty([0, 3])]
    parts.extend(src[i]["rr"][v] for i, v in enumerate(stc_est.vertices) if len(v))
    r_est = np.concatenate(parts, axis=0)
    return stc_est, r_true, r_est
def _peak_position_error(p, q, r_est, r_true):
    """Distance between the amplitude-weighted centroid of ``q`` and ``r_true``.

    ``p`` is unused; it is accepted so the signature matches what ``_apply``
    passes. Rows of ``q`` are weighted by the sum of their absolute values.
    If all weights are zero, ``np.inf`` is returned.
    """
    weights = _abs_col_sum(q)
    total = np.sum(weights)
    if not total:
        return np.inf
    weights /= total
    centroid = np.dot(weights, r_est)
    return np.linalg.norm(centroid - r_true)
@fill_doc
def peak_position_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the peak position error.

    The peak position error measures the distance between the center of mass
    of the estimated source and the true source.

    .. math::

        PPE = \| \dfrac{\sum_i|s_i|r_{i}}{\sum_i|s_i|}
        - r_{true}\|,

    where :math:`r_{true}` is the true dipole position, and :math:`r_i` and
    :math:`|s_i|` are the position and amplitude of the i-th dipole in the
    source estimate.

    The threshold is applied only to the estimated source. This focuses the
    metric on strong amplitudes and omits the low-amplitude values. The true
    source estimate must contain exactly one dipole (vertex); otherwise a
    ``ValueError`` is raised.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to the estimated source estimate before
        computing the peak position error. If a string, the threshold is a
        percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_peak_position_error, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric
def _spatial_deviation(p, q, r_est, r_true):
    """Amplitude-weighted RMS distance between ``r_est`` rows and ``r_true``.

    ``p`` is unused; it is accepted so the signature matches what ``_apply``
    passes. Rows of ``q`` are weighted by the sum of their absolute values.
    If all weights are zero, ``np.inf`` is returned.
    """
    weights = _abs_col_sum(q)
    total = np.sum(weights)
    if not total:
        return np.inf
    weights /= total
    offsets = r_est - np.tile(r_true, (r_est.shape[0], 1))
    squared_dist = np.sum(offsets**2, axis=1)
    return np.sqrt(np.dot(weights, squared_dist))
@fill_doc
def spatial_deviation_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the spatial deviation.

    The spatial deviation characterizes the spread of the estimated source
    around the true source.

    .. math::

        SD = \sqrt{\dfrac{\sum_i|s_i|\|r_{i} - r_{true}\|^2}{\sum_i|s_i|}},

    where :math:`r_{true}` is the true dipole position, and :math:`r_i` and
    :math:`|s_i|` are the position and amplitude of the i-th dipole in the
    source estimate.

    The threshold is applied only to the estimated source. This focuses the
    metric on strong amplitudes and omits the low-amplitude values. The true
    source estimate must contain exactly one dipole (vertex); otherwise a
    ``ValueError`` is raised.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to the estimated source estimate before
        computing the spatial deviation. If a string, the threshold is a
        percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_spatial_deviation, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric