# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_equal,
assert_array_less,
)
from mne.inverse_sparse.mxne_optim import (
_Phi,
_PhiT,
dgap_l21l1,
iterative_mixed_norm_solver,
iterative_tf_mixed_norm_solver,
mixed_norm_solver,
norm_epsilon,
norm_epsilon_inf,
tf_mixed_norm_solver,
)
from mne.time_frequency._stft import stft_norm2
from mne.utils import _record_warnings, catch_logging
def _generate_tf_data():
n, p, t = 30, 40, 64
rng = np.random.RandomState(0)
G = rng.randn(n, p)
G /= np.std(G, axis=0)[None, :]
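    # two active sources: a full-length sinusoid (index 0) and a burst of a
    # faster sinusoid restricted to (pi/2, pi) (index 4)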
X = np.zeros((p, t))
active_set = [0, 4]
times = np.linspace(0, 2 * np.pi, t)
X[0] = np.sin(times)
X[4] = -2 * np.sin(4 * times)
X[4, times <= np.pi / 2] = 0
X[4, times >= np.pi] = 0
M = np.dot(G, X)
M += 1 * rng.randn(*M.shape)
return M, G, active_set
def test_l21_mxne():
"""Test convergence of MxNE solver."""
n, p, t, alpha = 30, 40, 20, 1.0
rng = np.random.RandomState(0)
G = rng.randn(n, p)
G /= np.std(G, axis=0)[None, :]
X = np.zeros((p, t))
X[0] = 3
X[4] = -2
M = np.dot(G, X)
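    # positional args map to (M, G, alpha, maxit, tol)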
args = (M, G, alpha, 1000, 1e-8)
with _record_warnings(): # CD
X_hat_cd, active_set, _, gap_cd = mixed_norm_solver(
*args, active_set_size=None, debias=True, solver="cd", return_gap=True
)
assert_array_less(gap_cd, 1e-8)
assert_array_equal(np.where(active_set)[0], [0, 4])
with _record_warnings(): # CD
X_hat_bcd, active_set, E, gap_bcd = mixed_norm_solver(
M,
G,
alpha,
maxit=1000,
tol=1e-8,
active_set_size=None,
debias=True,
solver="bcd",
return_gap=True,
)
assert_array_less(gap_bcd, 9.6e-9)
assert_array_equal(np.where(active_set)[0], [0, 4])
assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2)
with _record_warnings(): # CD
X_hat_cd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, solver="cd"
)
assert_array_equal(np.where(active_set)[0], [0, 4])
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, solver="bcd"
)
assert_array_equal(np.where(active_set)[0], [0, 4])
assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2)
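    # with n_orient=2 the columns are grouped in pairs, so activating sources
    # 0 and 4 selects columns [0, 1] and [4, 5]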
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, n_orient=2, solver="bcd"
)
assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5])
    # the coordinate-descent solver is expected to emit a convergence warning here
with pytest.warns(RuntimeWarning, match="descent"):
X_hat_cd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, n_orient=2, solver="cd"
)
assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5])
assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2)
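    # with n_orient=5 the true sources 0 and 4 both fall in the first group of
    # 5 columns, so the whole group [0, 1, 2, 3, 4] is selected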
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, n_orient=5, solver="bcd"
)
assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4])
with pytest.warns(RuntimeWarning, match="descent"):
X_hat_cd, active_set, _ = mixed_norm_solver(
*args, active_set_size=2, debias=True, n_orient=5, solver="cd"
)
assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4])
assert_allclose(X_hat_bcd, X_hat_cd)
@pytest.mark.slowtest
def test_non_convergence():
"""Test non-convergence of MxNE solver to catch unexpected bugs."""
n, p, t, alpha = 30, 40, 20, 1.0
rng = np.random.RandomState(0)
G = rng.randn(n, p)
G /= np.std(G, axis=0)[None, :]
X = np.zeros((p, t))
X[0] = 3
X[4] = -2
M = np.dot(G, X)
# Impossible to converge with only 1 iteration and tol 1e-12
    # In case of non-convergence, we test that no error is raised.
args = (M, G, alpha, 1, 1e-12)
with catch_logging() as log:
mixed_norm_solver(
*args, active_set_size=None, debias=True, solver="bcd", verbose=True
)
log = log.getvalue()
assert "Convergence reached" not in log
def test_tf_mxne():
"""Test convergence of TF-MxNE solver."""
alpha_space = 10.0
alpha_time = 5.0
M, G, active_set = _generate_tf_data()
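    # the solver should close the duality gap below tol and recover exactly
    # the two simulated active sources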
with _record_warnings(): # CD
X_hat_tf, active_set_hat_tf, E, gap_tfmxne = tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
maxit=200,
tol=1e-8,
verbose=True,
n_orient=1,
tstep=4,
wsize=32,
return_gap=True,
)
assert_array_less(gap_tfmxne, 1e-8)
assert_array_equal(np.where(active_set_hat_tf)[0], active_set)
def test_norm_epsilon():
"""Test computation of espilon norm on TF coefficients."""
tstep = np.array([2])
wsize = np.array([4])
n_times = 10
n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
n_freqs = wsize // 2 + 1
n_coefs = n_steps * n_freqs
phi = _Phi(wsize, tstep, n_coefs, n_times)
Y = np.zeros((n_steps * n_freqs).item())
l1_ratio = 0.03
assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.0)
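    # with a single nonzero coefficient the epsilon norm equals its magnitude,
    # for both small and large l1_ratio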
Y[0] = 2.0
assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
l1_ratio = 1.0
assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
    # deterministic dummy values (no randomness):
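    # for l1_ratio == 0 the epsilon norm reduces to the l2 norm of the TF
    # coefficients, so its square should match stft_norm2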
Y = np.arange((n_steps * n_freqs).item())
l1_ratio = 0.0
assert_allclose(
norm_epsilon(Y, l1_ratio, phi) ** 2,
stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])),
)
l1_ratio = 0.03
    # test that the unweighted (vanilla) epsilon norm equals the weighted
    # norm when all weights are equal to 1
w_time = np.ones(n_coefs[0])
    rng = np.random.RandomState(0)
    Y = np.abs(rng.randn(n_coefs[0]))
assert_allclose(
norm_epsilon(Y, l1_ratio, phi), norm_epsilon(Y, l1_ratio, phi, w_time=w_time)
)
    # scaling w_space and w_time by a common factor should divide the
    # epsilon norm by that same factor
Y = np.arange(n_coefs.item()) + 1
mult = 2.0
assert_allclose(
norm_epsilon(Y, l1_ratio, phi, w_space=1, w_time=np.ones(n_coefs.item()))
/ mult,
norm_epsilon(
Y, l1_ratio, phi, w_space=mult, w_time=mult * np.ones(n_coefs.item())
),
)
@pytest.mark.slowtest # slow-ish on Travis OSX
@pytest.mark.timeout(60) # ~30 s on Travis OSX and Linux OpenBLAS
def test_dgapl21l1():
"""Test duality gap for L21 + L1 regularization."""
n_orient = 2
M, G, active_set = _generate_tf_data()
n_times = M.shape[1]
n_sources = G.shape[1]
tstep, wsize = np.array([4, 2]), np.array([64, 16])
n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
n_freqs = wsize // 2 + 1
n_coefs = n_steps * n_freqs
phi = _Phi(wsize, tstep, n_coefs, n_times)
phiT = _PhiT(tstep, n_freqs, n_steps, n_times)
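    # two stacked STFT dictionaries (wsize 64 and 16), so n_coefs has length 2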
for l1_ratio in [0.05, 0.1]:
alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient)
alpha_space = (1.0 - l1_ratio) * alpha_max
alpha_time = l1_ratio * alpha_max
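        # alpha_max is split between the spatial and temporal penalties
        # according to l1_ratio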
Z = np.zeros([n_sources, phi.n_coefs.sum()])
# for alpha = alpha_max, Z = 0 is the solution so the dgap is 0
gap = dgap_l21l1(
M,
G,
Z,
np.ones(n_sources, dtype=bool),
alpha_space,
alpha_time,
phi,
phiT,
n_orient,
-np.inf,
)[0]
assert_allclose(0.0, gap)
        # check that the solution is nonzero for alpha slightly below alpha_max:
X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
M,
G,
alpha_space / 1.01,
alpha_time / 1.01,
maxit=200,
tol=1e-8,
verbose=True,
debias=False,
n_orient=n_orient,
tstep=tstep,
wsize=wsize,
return_gap=True,
)
# allow possible small numerical errors (negative gap)
assert_array_less(-1e-10, gap)
assert_array_less(gap, 1e-8)
assert_array_less(1, len(active_set_hat_tf))
X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
M,
G,
alpha_space / 5.0,
alpha_time / 5.0,
maxit=200,
tol=1e-8,
verbose=True,
debias=False,
n_orient=n_orient,
tstep=tstep,
wsize=wsize,
return_gap=True,
)
assert_array_less(-1e-10, gap)
assert_array_less(gap, 1e-8)
assert_array_less(1, len(active_set_hat_tf))
def test_tf_mxne_vs_mxne():
"""Test equivalence of TF-MxNE (with alpha_time=0) and MxNE."""
alpha_space = 60.0
alpha_time = 0.0
M, G, active_set = _generate_tf_data()
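    # with alpha_time == 0 the temporal penalty vanishes, so TF-MxNE should
    # match plain MxNE (L21) up to solver tolerance (checked below)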
X_hat_tf, active_set_hat_tf, E = tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
maxit=200,
tol=1e-8,
verbose=True,
debias=False,
n_orient=1,
tstep=4,
wsize=32,
)
    # Also run the plain L21 (MxNE) solver and check that we get the same solution
X_hat_l21, _, _ = mixed_norm_solver(
M,
G,
alpha_space,
maxit=200,
tol=1e-8,
verbose=False,
n_orient=1,
active_set_size=None,
debias=False,
)
assert_allclose(X_hat_tf, X_hat_l21, rtol=1e-1)
@pytest.mark.slowtest # slow-ish on Travis OSX
def test_iterative_reweighted_mxne():
"""Test convergence of irMxNE solver."""
n, p, t, alpha = 30, 40, 20, 1
rng = np.random.RandomState(0)
G = rng.randn(n, p)
G /= np.std(G, axis=0)[None, :]
X = np.zeros((p, t))
X[0] = 3
X[4] = -2
M = np.dot(G, X)
with _record_warnings(): # CD
X_hat_l21, _, _ = mixed_norm_solver(
M,
G,
alpha,
maxit=1000,
tol=1e-8,
verbose=False,
n_orient=1,
active_set_size=None,
debias=False,
solver="bcd",
)
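    # with a single reweighting iteration, irMxNE should match plain MxNE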
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
1,
maxit=1000,
tol=1e-8,
active_set_size=None,
debias=False,
solver="bcd",
)
assert_allclose(X_hat_bcd, X_hat_l21, rtol=1e-3)
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
5,
maxit=1000,
tol=1e-8,
active_set_size=2,
debias=True,
solver="bcd",
)
assert_array_equal(np.where(active_set)[0], [0, 4])
with _record_warnings(): # CD
X_hat_cd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
5,
maxit=1000,
tol=1e-8,
active_set_size=None,
debias=True,
solver="cd",
)
assert_array_equal(np.where(active_set)[0], [0, 4])
assert_array_almost_equal(X_hat_bcd, X_hat_cd, 5)
with _record_warnings(): # CD
X_hat_bcd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
5,
maxit=1000,
tol=1e-8,
active_set_size=2,
debias=True,
n_orient=2,
solver="bcd",
)
assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5])
    # the coordinate-descent solver is expected to emit a convergence warning here
with pytest.warns(RuntimeWarning, match="descent"):
X_hat_cd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
5,
maxit=1000,
tol=1e-8,
active_set_size=2,
debias=True,
n_orient=2,
solver="cd",
)
assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5])
assert_allclose(X_hat_bcd, X_hat_cd)
X_hat_bcd, active_set, _ = iterative_mixed_norm_solver(
M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, debias=True, n_orient=5
)
assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4])
with pytest.warns(RuntimeWarning, match="descent"):
X_hat_cd, active_set, _ = iterative_mixed_norm_solver(
M,
G,
alpha,
5,
maxit=1000,
tol=1e-8,
active_set_size=2,
debias=True,
n_orient=5,
solver="cd",
)
assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4])
assert_allclose(X_hat_bcd, X_hat_cd)
@pytest.mark.slowtest
def test_iterative_reweighted_tfmxne():
"""Test convergence of irTF-MxNE solver."""
M, G, true_active_set = _generate_tf_data()
alpha_space = 38.0
alpha_time = 0.5
tstep, wsize = [4, 2], [64, 16]
X_hat_tf, _, _ = tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
maxit=1000,
tol=1e-4,
wsize=wsize,
tstep=tstep,
verbose=False,
n_orient=1,
debias=False,
)
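    # with a single reweighting pass, irTF-MxNE should match plain TF-MxNE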
X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
1,
wsize=wsize,
tstep=tstep,
maxit=1000,
tol=1e-4,
debias=False,
verbose=False,
)
assert_allclose(X_hat_tf, X_hat_bcd, rtol=1e-3)
assert_array_equal(np.where(active_set)[0], true_active_set)
alpha_space = 50.0
X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
3,
wsize=wsize,
tstep=tstep,
n_orient=5,
maxit=1000,
tol=1e-4,
debias=False,
verbose=False,
)
assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4])
alpha_space = 40.0
X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver(
M,
G,
alpha_space,
alpha_time,
2,
wsize=wsize,
tstep=tstep,
n_orient=2,
maxit=1000,
tol=1e-4,
debias=False,
verbose=False,
)
assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5])