"""
.. _tut-strf:

=====================================================================
Spectro-temporal receptive field (STRF) estimation on continuous data
=====================================================================

This demonstrates how an encoding model can be fit with multiple continuous
inputs. In this case, we simulate the model behind a spectro-temporal receptive
field (or STRF). First, we create a linear filter that maps patterns in
spectro-temporal space onto an output, representing neural activity. We fit
a receptive field model that attempts to recover the original linear filter
that was used to create this data.
"""
# Authors: Chris Holdgraf <choldgraf@gmail.com>
#          Eric Larson <larson.eric.d@gmail.com>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# %%

# sphinx_gallery_thumbnail_number = 7

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy.stats import multivariate_normal
from sklearn.preprocessing import scale

import mne
from mne.decoding import ReceptiveField, TimeDelayingRidge

rng = np.random.RandomState(1337)  # To make this example reproducible

# %%
# Load audio data
# ---------------
#
# We'll read in the audio data from :footcite:`CrosseEtAl2016` in order to
# simulate a response.
#
# In addition, we'll downsample the data along the time dimension in order to
# speed up computation. Note that depending on the input values, this may not
# be desirable, for example if your input stimulus varies more quickly than
# half of the sampling rate to which we are downsampling (the new Nyquist
# frequency).

# Read in audio that's been recorded in epochs.
path_audio = mne.datasets.mtrf.data_path()
data = loadmat(str(path_audio / "speech_data.mat"))
audio = data["spectrogram"].T
sfreq = float(data["Fs"][0, 0])
n_decim = 2
audio = mne.filter.resample(audio, down=n_decim, npad="auto")
sfreq /= n_decim
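
# Quick sanity check (an illustrative addition, not part of the original
# pipeline): after decimation the effective Nyquist frequency is half of the
# new sampling rate, and the spectrogram envelope is assumed to vary more
# slowly than this.
print(f"New sampling rate: {sfreq:.1f} Hz (Nyquist: {sfreq / 2:.1f} Hz)")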

# %%
# Create a receptive field
# ------------------------
#
# We'll simulate a linear receptive field for a theoretical neural signal. This
# defines how the signal will respond to power in this receptive field space.
n_freqs = 20
tmin, tmax = -0.1, 0.4

# To simulate the data we'll create explicit delays here
delays_samp = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)
delays_sec = delays_samp / sfreq
freqs = np.linspace(50, 5000, n_freqs)
grid = np.array(np.meshgrid(delays_sec, freqs))

# Reshape the grid to (n_freqs, n_delays, 2) so we can evaluate 2D Gaussians on it
grid = grid.swapaxes(0, -1).swapaxes(0, 1)

# Simulate a spectro-temporal receptive field as the sum of a positive and a
# negative Gaussian (a Gabor-like filter)
means_high = [0.1, 500]
means_low = [0.2, 2500]
cov = [[0.001, 0], [0, 500000]]
gauss_high = multivariate_normal.pdf(grid, means_high, cov)
gauss_low = -1 * multivariate_normal.pdf(grid, means_low, cov)
weights = gauss_high + gauss_low  # Combine to create the "true" STRF
kwargs = dict(
    vmax=np.abs(weights).max(),
    vmin=-np.abs(weights).max(),
    cmap="RdBu_r",
    shading="gouraud",
)

fig, ax = plt.subplots(layout="constrained")
ax.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax.set(title="Simulated STRF", xlabel="Time Lags (s)", ylabel="Frequency (Hz)")
plt.setp(ax.get_xticklabels(), rotation=45)

# %%
# Simulate a neural response
# --------------------------
#
# Using this receptive field, we'll create an artificial neural response to
# a stimulus.
#
# To do this, we'll create a time-delayed version of the receptive field, and
# then calculate the dot product between this and the stimulus. Note that this
# is effectively doing a convolution between the stimulus and the receptive
# field. See `here <https://en.wikipedia.org/wiki/Convolution>`_ for more
# information.

# Reshape audio to split into epochs, then make epochs the first dimension.
n_epochs, n_seconds = 16, 5
audio = audio[:, : int(n_seconds * sfreq * n_epochs)]
X = audio.reshape([n_freqs, n_epochs, -1]).swapaxes(0, 1)
n_times = X.shape[-1]

# Delay the spectrogram according to delays so it can be combined with the STRF
# Lags will now be in axis 1, then we reshape to vectorize
delays = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)

# Iterate through indices and append
X_del = np.zeros((len(delays),) + X.shape)
for ii, ix_delay in enumerate(delays):
    # These arrays will take/put particular indices in the data
    take = [slice(None)] * X.ndim
    put = [slice(None)] * X.ndim
    if ix_delay > 0:
        take[-1] = slice(None, -ix_delay)
        put[-1] = slice(ix_delay, None)
    elif ix_delay < 0:
        take[-1] = slice(-ix_delay, None)
        put[-1] = slice(None, ix_delay)
    X_del[ii][tuple(put)] = X[tuple(take)]

# Now set the delayed axis to the 2nd dimension
X_del = np.rollaxis(X_del, 0, 3)
X_del = X_del.reshape([n_epochs, -1, n_times])
n_features = X_del.shape[1]
weights_sim = weights.ravel()

# Simulate a neural response to the sound, given this STRF
y = np.zeros((n_epochs, n_times))
for ii, iep in enumerate(X_del):
    # Simulate this epoch and add random noise
    noise_amp = 0.002
    y[ii] = np.dot(weights_sim, iep) + noise_amp * rng.randn(n_times)
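
# The delay-and-dot-product above is equivalent to convolving each frequency
# band of the spectrogram with the corresponding row of the STRF and summing
# across bands. Here is a minimal check of that equivalence for the first
# epoch (an illustrative addition; the offset accounts for the acausal lags):
y_conv = np.zeros(n_times)
for x_band, w_band in zip(X[0], weights):
    full = np.convolve(x_band, w_band, mode="full")
    y_conv += full[-delays.min() : -delays.min() + n_times]
# y_conv now matches y[0] up to the added noise term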

# Plot the first 2 trials of audio and the simulated electrode activity
X_plt = scale(np.hstack(X[:2]).T).T
y_plt = scale(np.hstack(y[:2]))
time = np.arange(X_plt.shape[-1]) / sfreq
_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True, layout="constrained")
ax1.pcolormesh(time, freqs, X_plt, vmin=0, vmax=4, cmap="Reds", shading="gouraud")
ax1.set_title("Input auditory features")
ax1.set(ylim=[freqs.min(), freqs.max()], ylabel="Frequency (Hz)")
ax2.plot(time, y_plt)
ax2.set(
    xlim=[time.min(), time.max()],
    title="Simulated response",
    xlabel="Time (s)",
    ylabel="Activity (a.u.)",
)


# %%
# Fit a model to recover this receptive field
# -------------------------------------------
#
# Finally, we'll use the :class:`mne.decoding.ReceptiveField` class to recover
# the linear receptive field of this signal. Note that properties of the
# receptive field (e.g. smoothness) will depend on the autocorrelation in the
# inputs and outputs.

# Create training and testing data
train, test = np.arange(n_epochs - 1), n_epochs - 1
X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
X_train, X_test, y_train, y_test = (
    np.rollaxis(ii, -1, 0) for ii in (X_train, X_test, y_train, y_test)
)
# Model the simulated data as a function of the spectrogram input
alphas = np.logspace(-3, 3, 7)
scores = np.zeros_like(alphas)
models = []
for ii, alpha in enumerate(alphas):
    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=alpha)
    rf.fit(X_train, y_train)

    # Now make predictions about the model output, given input stimuli.
    scores[ii] = rf.score(X_test, y_test).item()
    models.append(rf)

times = rf.delays_ / float(rf.sfreq)

# Choose the model that performed best on the held out data
ix_best_alpha = np.argmax(scores)
best_mod = models[ix_best_alpha]
coefs = best_mod.coef_[0]
best_pred = best_mod.predict(X_test)[:, 0]

# Plot the original STRF, and the one that we recovered with modeling.
_, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(6, 3),
    sharey=True,
    sharex=True,
    layout="constrained",
)
ax1.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax2.pcolormesh(times, rf.feature_names, coefs, **kwargs)
ax1.set_title("Original STRF")
ax2.set_title("Best Reconstructed STRF")
plt.setp([iax.get_xticklabels() for iax in [ax1, ax2]], rotation=45)

# Plot the actual response and the predicted response on a held out stimulus
time_pred = np.arange(best_pred.shape[0]) / sfreq
fig, ax = plt.subplots()
ax.plot(time_pred, y_test, color="k", alpha=0.2, lw=4)
ax.plot(time_pred, best_pred, color="r", lw=1)
ax.set(title="Original and predicted activity", xlabel="Time (s)")
ax.legend(["Original", "Predicted"])

# %%
# Visualize the effects of regularization
# ---------------------------------------
#
# Above we fit :class:`mne.decoding.ReceptiveField` models for several values
# of the ridge regularization parameter and kept the best one. Here we will
# plot the model score as well as the model coefficients for each value, in
# order to visualize how coefficients change with different levels of
# regularization. These issues as well as the STRF pipeline are described in
# detail in :footcite:`TheunissenEtAl2001,WillmoreSmyth2003,HoldgrafEtAl2016`.

# Plot model score for each ridge parameter
fig = plt.figure(figsize=(10, 4), layout="constrained")
ax = plt.subplot2grid([2, len(alphas)], [1, 0], 1, len(alphas))
ax.plot(np.arange(len(alphas)), scores, marker="o", color="r")
ax.annotate(
    "Best parameter",
    (ix_best_alpha, scores[ix_best_alpha]),
    (ix_best_alpha, scores[ix_best_alpha] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas])
ax.set(
    xlabel="Ridge regularization value",
    ylabel="Score ($R^2$)",
    xlim=[-0.4, len(alphas) - 0.6],
)

# Plot the STRF of each ridge parameter
for ii, (rf, i_alpha) in enumerate(zip(models, alphas)):
    ax = plt.subplot2grid([2, len(alphas)], [0, ii], 1, 1)
    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
    plt.xticks([], [])
    plt.yticks([], [])
fig.suptitle("Model coefficients / scores for many ridge parameters", y=1)

# %%
# Using different regularization types
# ------------------------------------
# In addition to the standard ridge regularization, the
# :class:`mne.decoding.TimeDelayingRidge` class also exposes a
# `Laplacian <https://en.wikipedia.org/wiki/Laplacian_matrix>`_ regularization
# term of the form:
#
# .. math::
#        \left[\begin{matrix}
#             1 & -1 &        &        &        &    \\
#            -1 &  2 & -1     &        &        &    \\
#               & -1 &  2     & -1     &        &    \\
#               &    & \ddots & \ddots & \ddots &    \\
#               &    &        & -1     &  2     & -1 \\
#               &    &        &        & -1     &  1\end{matrix}\right]
#
# This imposes a smoothness constraint across nearby time samples and/or
# features. Quoting :footcite:`CrosseEtAl2016` :
#
#     Tikhonov [identity] regularization (Equation 5) reduces overfitting by
#     smoothing the TRF estimate in a way that is insensitive to
#     the amplitude of the signal of interest. However, the Laplacian
#     approach (Equation 6) reduces off-sample error whilst preserving
#     signal amplitude (Lalor et al., 2006). As a result, this approach
#     usually leads to an improved estimate of the system’s response (as
#     indexed by MSE) compared to Tikhonov regularization.
#
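
# As a small illustration (an addition to the original tutorial), the matrix
# shown above can be built explicitly. Note that TimeDelayingRidge constructs
# its own regularization term internally, so this is for inspection only.
n_lags = len(delays_samp)
laplacian = 2 * np.eye(n_lags) - np.eye(n_lags, k=1) - np.eye(n_lags, k=-1)
laplacian[0, 0] = laplacian[-1, -1] = 1  # boundary terms, as in the matrix above
print(laplacian[:4, :4])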

scores_lap = np.zeros_like(alphas)
models_lap = []
for ii, alpha in enumerate(alphas):
    estimator = TimeDelayingRidge(tmin, tmax, sfreq, reg_type="laplacian", alpha=alpha)
    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=estimator)
    rf.fit(X_train, y_train)

    # Now make predictions about the model output, given input stimuli.
    scores_lap[ii] = rf.score(X_test, y_test).item()
    models_lap.append(rf)

ix_best_alpha_lap = np.argmax(scores_lap)

# %%
# Compare model performance
# -------------------------
# Below we visualize the model performance of each regularization method
# (ridge vs. Laplacian) for different levels of alpha. As you can see, the
# Laplacian method performs better in general, because it imposes a smoothness
# constraint along the time and feature dimensions of the coefficients.
# This matches the "true" receptive field structure and results in a better
# model fit.

fig = plt.figure(figsize=(10, 6), layout="constrained")
ax = plt.subplot2grid([3, len(alphas)], [2, 0], 1, len(alphas))
ax.plot(np.arange(len(alphas)), scores_lap, marker="o", color="r")
ax.plot(np.arange(len(alphas)), scores, marker="o", color="0.5", ls=":")
ax.annotate(
    "Best Laplacian",
    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap]),
    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
ax.annotate(
    "Best Ridge",
    (ix_best_alpha, scores[ix_best_alpha]),
    (ix_best_alpha, scores[ix_best_alpha] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas])
ax.set(
    xlabel="Laplacian regularization value",
    ylabel="Score ($R^2$)",
    xlim=[-0.4, len(alphas) - 0.6],
)

# Plot the STRF for each regularization value (Laplacian on top, ridge below)
xlim = times[[0, -1]]
for ii, (rf_lap, rf, i_alpha) in enumerate(zip(models_lap, models, alphas)):
    ax = plt.subplot2grid([3, len(alphas)], [0, ii], 1, 1)
    ax.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)
    ax.set(xticks=[], yticks=[], xlim=xlim)
    if ii == 0:
        ax.set(ylabel="Laplacian")
    ax = plt.subplot2grid([3, len(alphas)], [1, ii], 1, 1)
    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
    ax.set(xticks=[], yticks=[], xlim=xlim)
    if ii == 0:
        ax.set(ylabel="Ridge")
fig.suptitle("Model coefficients / scores for Laplacian regularization", y=1)

# %%
# Plot the original STRF, and the one that we recovered with modeling.
rf = models[ix_best_alpha]
rf_lap = models_lap[ix_best_alpha_lap]
_, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(9, 3),
    sharey=True,
    sharex=True,
    layout="constrained",
)
ax1.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax2.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
ax3.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)
ax1.set_title("Original STRF")
ax2.set_title("Best Ridge STRF")
ax3.set_title("Best Laplacian STRF")
plt.setp([iax.get_xticklabels() for iax in [ax1, ax2, ax3]], rotation=45)

# %%
# References
# ==========
# .. footbibliography::