--- a +++ b/tutorials/machine-learning/30_strf.py @@ -0,0 +1,369 @@ +""" +.. _tut-strf: + +===================================================================== +Spectro-temporal receptive field (STRF) estimation on continuous data +===================================================================== + +This demonstrates how an encoding model can be fit with multiple continuous +inputs. In this case, we simulate the model behind a spectro-temporal receptive +field (or STRF). First, we create a linear filter that maps patterns in +spectro-temporal space onto an output, representing neural activity. We fit +a receptive field model that attempts to recover the original linear filter +that was used to create this data. +""" +# Authors: Chris Holdgraf <choldgraf@gmail.com> +# Eric Larson <larson.eric.d@gmail.com> +# +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# %% + +# sphinx_gallery_thumbnail_number = 7 + +import matplotlib.pyplot as plt +import numpy as np +from scipy.io import loadmat +from scipy.stats import multivariate_normal +from sklearn.preprocessing import scale + +import mne +from mne.decoding import ReceptiveField, TimeDelayingRidge + +rng = np.random.RandomState(1337) # To make this example reproducible + +# %% +# Load audio data +# --------------- +# +# We'll read in the audio data from :footcite:`CrosseEtAl2016` in order to +# simulate a response. +# +# In addition, we'll downsample the data along the time dimension in order to +# speed up computation. Note that depending on the input values, this may +# not be desired. For example if your input stimulus varies more quickly than +# 1/2 the sampling rate to which we are downsampling. + +# Read in audio that's been recorded in epochs. +path_audio = mne.datasets.mtrf.data_path() +data = loadmat(str(path_audio / "speech_data.mat")) +audio = data["spectrogram"].T +sfreq = float(data["Fs"][0, 0]) +n_decim = 2 +audio = mne.filter.resample(audio, down=n_decim, npad="auto") +sfreq /= n_decim + +# %% +# Create a receptive field +# ------------------------ +# +# We'll simulate a linear receptive field for a theoretical neural signal. This +# defines how the signal will respond to power in this receptive field space. +n_freqs = 20 +tmin, tmax = -0.1, 0.4 + +# To simulate the data we'll create explicit delays here +delays_samp = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int) +delays_sec = delays_samp / sfreq +freqs = np.linspace(50, 5000, n_freqs) +grid = np.array(np.meshgrid(delays_sec, freqs)) + +# We need data to be shaped as n_epochs, n_features, n_times, so swap axes here +grid = grid.swapaxes(0, -1).swapaxes(0, 1) + +# Simulate a temporal receptive field with a Gabor filter +means_high = [0.1, 500] +means_low = [0.2, 2500] +cov = [[0.001, 0], [0, 500000]] +gauss_high = multivariate_normal.pdf(grid, means_high, cov) +gauss_low = -1 * multivariate_normal.pdf(grid, means_low, cov) +weights = gauss_high + gauss_low # Combine to create the "true" STRF +kwargs = dict( + vmax=np.abs(weights).max(), + vmin=-np.abs(weights).max(), + cmap="RdBu_r", + shading="gouraud", +) + +fig, ax = plt.subplots(layout="constrained") +ax.pcolormesh(delays_sec, freqs, weights, **kwargs) +ax.set(title="Simulated STRF", xlabel="Time Lags (s)", ylabel="Frequency (Hz)") +plt.setp(ax.get_xticklabels(), rotation=45) + +# %% +# Simulate a neural response +# -------------------------- +# +# Using this receptive field, we'll create an artificial neural response to +# a stimulus. +# +# To do this, we'll create a time-delayed version of the receptive field, and +# then calculate the dot product between this and the stimulus. Note that this +# is effectively doing a convolution between the stimulus and the receptive +# field. See `here <https://en.wikipedia.org/wiki/Convolution>`_ for more +# information. + +# Reshape audio to split into epochs, then make epochs the first dimension. +n_epochs, n_seconds = 16, 5 +audio = audio[:, : int(n_seconds * sfreq * n_epochs)] +X = audio.reshape([n_freqs, n_epochs, -1]).swapaxes(0, 1) +n_times = X.shape[-1] + +# Delay the spectrogram according to delays so it can be combined w/ the STRF +# Lags will now be in axis 1, then we reshape to vectorize +delays = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int) + +# Iterate through indices and append +X_del = np.zeros((len(delays),) + X.shape) +for ii, ix_delay in enumerate(delays): + # These arrays will take/put particular indices in the data + take = [slice(None)] * X.ndim + put = [slice(None)] * X.ndim + if ix_delay > 0: + take[-1] = slice(None, -ix_delay) + put[-1] = slice(ix_delay, None) + elif ix_delay < 0: + take[-1] = slice(-ix_delay, None) + put[-1] = slice(None, ix_delay) + X_del[ii][tuple(put)] = X[tuple(take)] + +# Now set the delayed axis to the 2nd dimension +X_del = np.rollaxis(X_del, 0, 3) +X_del = X_del.reshape([n_epochs, -1, n_times]) +n_features = X_del.shape[1] +weights_sim = weights.ravel() + +# Simulate a neural response to the sound, given this STRF +y = np.zeros((n_epochs, n_times)) +for ii, iep in enumerate(X_del): + # Simulate this epoch and add random noise + noise_amp = 0.002 + y[ii] = np.dot(weights_sim, iep) + noise_amp * rng.randn(n_times) + +# Plot the first 2 trials of audio and the simulated electrode activity +X_plt = scale(np.hstack(X[:2]).T).T +y_plt = scale(np.hstack(y[:2])) +time = np.arange(X_plt.shape[-1]) / sfreq +_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True, layout="constrained") +ax1.pcolormesh(time, freqs, X_plt, vmin=0, vmax=4, cmap="Reds", shading="gouraud") +ax1.set_title("Input auditory features") +ax1.set(ylim=[freqs.min(), freqs.max()], ylabel="Frequency (Hz)") +ax2.plot(time, y_plt) +ax2.set( + xlim=[time.min(), time.max()], + title="Simulated response", + xlabel="Time (s)", + ylabel="Activity (a.u.)", +) + + +# %% +# Fit a model to recover this receptive field +# ------------------------------------------- +# +# Finally, we'll use the :class:`mne.decoding.ReceptiveField` class to recover +# the linear receptive field of this signal. Note that properties of the +# receptive field (e.g. smoothness) will depend on the autocorrelation in the +# inputs and outputs. + +# Create training and testing data +train, test = np.arange(n_epochs - 1), n_epochs - 1 +X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test] +X_train, X_test, y_train, y_test = ( + np.rollaxis(ii, -1, 0) for ii in (X_train, X_test, y_train, y_test) +) +# Model the simulated data as a function of the spectrogram input +alphas = np.logspace(-3, 3, 7) +scores = np.zeros_like(alphas) +models = [] +for ii, alpha in enumerate(alphas): + rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=alpha) + rf.fit(X_train, y_train) + + # Now make predictions about the model output, given input stimuli. + scores[ii] = rf.score(X_test, y_test).item() + models.append(rf) + +times = rf.delays_ / float(rf.sfreq) + +# Choose the model that performed best on the held out data +ix_best_alpha = np.argmax(scores) +best_mod = models[ix_best_alpha] +coefs = best_mod.coef_[0] +best_pred = best_mod.predict(X_test)[:, 0] + +# Plot the original STRF, and the one that we recovered with modeling. +_, (ax1, ax2) = plt.subplots( + 1, + 2, + figsize=(6, 3), + sharey=True, + sharex=True, + layout="constrained", +) +ax1.pcolormesh(delays_sec, freqs, weights, **kwargs) +ax2.pcolormesh(times, rf.feature_names, coefs, **kwargs) +ax1.set_title("Original STRF") +ax2.set_title("Best Reconstructed STRF") +plt.setp([iax.get_xticklabels() for iax in [ax1, ax2]], rotation=45) + +# Plot the actual response and the predicted response on a held out stimulus +time_pred = np.arange(best_pred.shape[0]) / sfreq +fig, ax = plt.subplots() +ax.plot(time_pred, y_test, color="k", alpha=0.2, lw=4) +ax.plot(time_pred, best_pred, color="r", lw=1) +ax.set(title="Original and predicted activity", xlabel="Time (s)") +ax.legend(["Original", "Predicted"]) + + +# %% +# Visualize the effects of regularization +# --------------------------------------- +# +# Above we fit a :class:`mne.decoding.ReceptiveField` model for one of many +# values for the ridge regularization parameter. Here we will plot the model +# score as well as the model coefficients for each value, in order to +# visualize how coefficients change with different levels of regularization. +# These issues as well as the STRF pipeline are described in detail +# in :footcite:`TheunissenEtAl2001,WillmoreSmyth2003,HoldgrafEtAl2016`. + +# Plot model score for each ridge parameter +fig = plt.figure(figsize=(10, 4), layout="constrained") +ax = plt.subplot2grid([2, len(alphas)], [1, 0], 1, len(alphas)) +ax.plot(np.arange(len(alphas)), scores, marker="o", color="r") +ax.annotate( + "Best parameter", + (ix_best_alpha, scores[ix_best_alpha]), + (ix_best_alpha, scores[ix_best_alpha] - 0.1), + arrowprops={"arrowstyle": "->"}, +) +plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas]) +ax.set( + xlabel="Ridge regularization value", + ylabel="Score ($R^2$)", + xlim=[-0.4, len(alphas) - 0.6], +) + +# Plot the STRF of each ridge parameter +for ii, (rf, i_alpha) in enumerate(zip(models, alphas)): + ax = plt.subplot2grid([2, len(alphas)], [0, ii], 1, 1) + ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs) + plt.xticks([], []) + plt.yticks([], []) +fig.suptitle("Model coefficients / scores for many ridge parameters", y=1) + +# %% +# Using different regularization types +# ------------------------------------ +# In addition to the standard ridge regularization, the +# :class:`mne.decoding.TimeDelayingRidge` class also exposes +# `Laplacian <https://en.wikipedia.org/wiki/Laplacian_matrix>`_ regularization +# term as: +# +# .. math:: +# \left[\begin{matrix} +# 1 & -1 & & & & \\ +# -1 & 2 & -1 & & & \\ +# & -1 & 2 & -1 & & \\ +# & & \ddots & \ddots & \ddots & \\ +# & & & -1 & 2 & -1 \\ +# & & & & -1 & 1\end{matrix}\right] +# +# This imposes a smoothness constraint of nearby time samples and/or features. +# Quoting :footcite:`CrosseEtAl2016` : +# +# Tikhonov [identity] regularization (Equation 5) reduces overfitting by +# smoothing the TRF estimate in a way that is insensitive to +# the amplitude of the signal of interest. However, the Laplacian +# approach (Equation 6) reduces off-sample error whilst preserving +# signal amplitude (Lalor et al., 2006). As a result, this approach +# usually leads to an improved estimate of the system’s response (as +# indexed by MSE) compared to Tikhonov regularization. +# + +scores_lap = np.zeros_like(alphas) +models_lap = [] +for ii, alpha in enumerate(alphas): + estimator = TimeDelayingRidge(tmin, tmax, sfreq, reg_type="laplacian", alpha=alpha) + rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=estimator) + rf.fit(X_train, y_train) + + # Now make predictions about the model output, given input stimuli. + scores_lap[ii] = rf.score(X_test, y_test).item() + models_lap.append(rf) + +ix_best_alpha_lap = np.argmax(scores_lap) + +# %% +# Compare model performance +# ------------------------- +# Below we visualize the model performance of each regularization method +# (ridge vs. Laplacian) for different levels of alpha. As you can see, the +# Laplacian method performs better in general, because it imposes a smoothness +# constraint along the time and feature dimensions of the coefficients. +# This matches the "true" receptive field structure and results in a better +# model fit. + +fig = plt.figure(figsize=(10, 6), layout="constrained") +ax = plt.subplot2grid([3, len(alphas)], [2, 0], 1, len(alphas)) +ax.plot(np.arange(len(alphas)), scores_lap, marker="o", color="r") +ax.plot(np.arange(len(alphas)), scores, marker="o", color="0.5", ls=":") +ax.annotate( + "Best Laplacian", + (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap]), + (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap] - 0.1), + arrowprops={"arrowstyle": "->"}, +) +ax.annotate( + "Best Ridge", + (ix_best_alpha, scores[ix_best_alpha]), + (ix_best_alpha, scores[ix_best_alpha] - 0.1), + arrowprops={"arrowstyle": "->"}, +) +plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas]) +ax.set( + xlabel="Laplacian regularization value", + ylabel="Score ($R^2$)", + xlim=[-0.4, len(alphas) - 0.6], +) + +# Plot the STRF of each ridge parameter +xlim = times[[0, -1]] +for ii, (rf_lap, rf, i_alpha) in enumerate(zip(models_lap, models, alphas)): + ax = plt.subplot2grid([3, len(alphas)], [0, ii], 1, 1) + ax.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs) + ax.set(xticks=[], yticks=[], xlim=xlim) + if ii == 0: + ax.set(ylabel="Laplacian") + ax = plt.subplot2grid([3, len(alphas)], [1, ii], 1, 1) + ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs) + ax.set(xticks=[], yticks=[], xlim=xlim) + if ii == 0: + ax.set(ylabel="Ridge") +fig.suptitle("Model coefficients / scores for laplacian regularization", y=1) + +# %% +# Plot the original STRF, and the one that we recovered with modeling. +rf = models[ix_best_alpha] +rf_lap = models_lap[ix_best_alpha_lap] +_, (ax1, ax2, ax3) = plt.subplots( + 1, + 3, + figsize=(9, 3), + sharey=True, + sharex=True, + layout="constrained", +) +ax1.pcolormesh(delays_sec, freqs, weights, **kwargs) +ax2.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs) +ax3.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs) +ax1.set_title("Original STRF") +ax2.set_title("Best Ridge STRF") +ax3.set_title("Best Laplacian STRF") +plt.setp([iax.get_xticklabels() for iax in [ax1, ax2, ax3]], rotation=45) + +# %% +# References +# ========== +# .. footbibliography::