"""
2
.. _tut-strf:
3
4
=====================================================================
5
Spectro-temporal receptive field (STRF) estimation on continuous data
6
=====================================================================
7
8
This demonstrates how an encoding model can be fit with multiple continuous
9
inputs. In this case, we simulate the model behind a spectro-temporal receptive
10
field (or STRF). First, we create a linear filter that maps patterns in
11
spectro-temporal space onto an output, representing neural activity. We fit
12
a receptive field model that attempts to recover the original linear filter
13
that was used to create this data.
14
"""
15
# Authors: Chris Holdgraf <choldgraf@gmail.com>
16
#          Eric Larson <larson.eric.d@gmail.com>
17
#
18
# License: BSD-3-Clause
19
# Copyright the MNE-Python contributors.
20
21
# %%
22
23
# sphinx_gallery_thumbnail_number = 7
24
25
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy.stats import multivariate_normal
from sklearn.preprocessing import scale

import mne
from mne.decoding import ReceptiveField, TimeDelayingRidge

rng = np.random.RandomState(1337)  # To make this example reproducible
# %%
# Load audio data
# ---------------
#
# We'll read in the audio data from :footcite:`CrosseEtAl2016` in order to
# simulate a response.
#
# In addition, we'll downsample the data along the time dimension in order to
# speed up computation. Note that depending on the input values, this may not
# be desired: for example, if your input stimulus varies more quickly than half
# of the new (downsampled) sampling rate, that variation will be lost.
# Read in audio that's been recorded in epochs.
path_audio = mne.datasets.mtrf.data_path()
data = loadmat(str(path_audio / "speech_data.mat"))
audio = data["spectrogram"].T
sfreq = float(data["Fs"][0, 0])
n_decim = 2
audio = mne.filter.resample(audio, down=n_decim, npad="auto")
sfreq /= n_decim
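
# A quick illustrative check (not required by the pipeline): after decimating
# by 2, anything in the stimulus varying faster than the new Nyquist frequency
# (sfreq / 2) has been removed by the resampling filter.
print(f"New sampling rate: {sfreq:0.1f} Hz (Nyquist: {sfreq / 2:0.1f} Hz)")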

# %%
# Create a receptive field
# ------------------------
#
# We'll simulate a linear receptive field for a theoretical neural signal. This
# defines how the signal will respond to power in this receptive field space.
n_freqs = 20
tmin, tmax = -0.1, 0.4

# To simulate the data we'll create explicit delays here
delays_samp = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)
delays_sec = delays_samp / sfreq
freqs = np.linspace(50, 5000, n_freqs)
grid = np.array(np.meshgrid(delays_sec, freqs))

# Put the coordinate axis last so each grid entry is a (delay, frequency) pair,
# as expected by multivariate_normal.pdf
grid = grid.swapaxes(0, -1).swapaxes(0, 1)

# Simulate a receptive field as the sum of a positive (excitatory) and a
# negative (inhibitory) Gaussian, giving a Gabor-like filter
means_high = [0.1, 500]
means_low = [0.2, 2500]
cov = [[0.001, 0], [0, 500000]]
gauss_high = multivariate_normal.pdf(grid, means_high, cov)
gauss_low = -1 * multivariate_normal.pdf(grid, means_low, cov)
weights = gauss_high + gauss_low  # Combine to create the "true" STRF
kwargs = dict(
    vmax=np.abs(weights).max(),
    vmin=-np.abs(weights).max(),
    cmap="RdBu_r",
    shading="gouraud",
)

fig, ax = plt.subplots(layout="constrained")
ax.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax.set(title="Simulated STRF", xlabel="Time Lags (s)", ylabel="Frequency (Hz)")
plt.setp(ax.get_xticklabels(), rotation=45)

# %%
# Simulate a neural response
# --------------------------
#
# Using this receptive field, we'll create an artificial neural response to
# a stimulus.
#
# To do this, we'll create a time-delayed version of the receptive field, and
# then calculate the dot product between this and the stimulus. Note that this
# is effectively doing a convolution between the stimulus and the receptive
# field. See `here <https://en.wikipedia.org/wiki/Convolution>`_ for more
# information.
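
# As a minimal sketch of that equivalence (a toy 1-D example, separate from the
# pipeline below, using hypothetical ``toy_*`` names), the weighted sum over
# time-delayed copies of a stimulus is the same operation as a convolution:
toy_rng = np.random.RandomState(0)
toy_stim = toy_rng.randn(100)
toy_filt = np.array([0.5, 1.0, -0.5])  # weights at lags of 0, 1, and 2 samples
toy_delayed = np.stack([np.roll(toy_stim, lag) for lag in range(3)])
toy_delayed[1, :1] = 0  # zero out the samples that wrapped around
toy_delayed[2, :2] = 0
resp_dot = toy_filt @ toy_delayed
resp_conv = np.convolve(toy_stim, toy_filt, mode="full")[: len(toy_stim)]
np.testing.assert_allclose(resp_dot, resp_conv)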

# Reshape audio to split into epochs, then make epochs the first dimension.
n_epochs, n_seconds = 16, 5
audio = audio[:, : int(n_seconds * sfreq * n_epochs)]
X = audio.reshape([n_freqs, n_epochs, -1]).swapaxes(0, 1)
n_times = X.shape[-1]
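# Illustrative check of the resulting layout: epochs first, then frequencies
# (features), then time samples within each epoch.
assert X.shape == (n_epochs, n_freqs, n_times)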

# Delay the spectrogram according to delays so it can be combined w/ the STRF
# Lags will now be in axis 1, then we reshape to vectorize
delays = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)

# Iterate through indices and append
X_del = np.zeros((len(delays),) + X.shape)
for ii, ix_delay in enumerate(delays):
    # These arrays will take/put particular indices in the data
    take = [slice(None)] * X.ndim
    put = [slice(None)] * X.ndim
    if ix_delay > 0:
        take[-1] = slice(None, -ix_delay)
        put[-1] = slice(ix_delay, None)
    elif ix_delay < 0:
        take[-1] = slice(-ix_delay, None)
        put[-1] = slice(None, ix_delay)
    X_del[ii][tuple(put)] = X[tuple(take)]

# Now set the delayed axis to the 2nd dimension
X_del = np.rollaxis(X_del, 0, 3)
X_del = X_del.reshape([n_epochs, -1, n_times])
n_features = X_del.shape[1]
weights_sim = weights.ravel()
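
# The flattened feature axis now holds one value per (frequency, delay) pair,
# in the same order as the raveled "true" STRF; an illustrative check:
assert n_features == n_freqs * len(delays) == weights_sim.size
# (This kind of delay embedding is essentially what
# :class:`mne.decoding.ReceptiveField` does internally to its input features.)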

# Simulate a neural response to the sound, given this STRF
y = np.zeros((n_epochs, n_times))
for ii, iep in enumerate(X_del):
    # Simulate this epoch and add random noise
    noise_amp = 0.002
    y[ii] = np.dot(weights_sim, iep) + noise_amp * rng.randn(n_times)

# Plot the first 2 trials of audio and the simulated electrode activity
X_plt = scale(np.hstack(X[:2]).T).T
y_plt = scale(np.hstack(y[:2]))
time = np.arange(X_plt.shape[-1]) / sfreq
_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True, layout="constrained")
ax1.pcolormesh(time, freqs, X_plt, vmin=0, vmax=4, cmap="Reds", shading="gouraud")
ax1.set_title("Input auditory features")
ax1.set(ylim=[freqs.min(), freqs.max()], ylabel="Frequency (Hz)")
ax2.plot(time, y_plt)
ax2.set(
    xlim=[time.min(), time.max()],
    title="Simulated response",
    xlabel="Time (s)",
    ylabel="Activity (a.u.)",
)

# %%
# Fit a model to recover this receptive field
# -------------------------------------------
#
# Finally, we'll use the :class:`mne.decoding.ReceptiveField` class to recover
# the linear receptive field of this signal. Note that properties of the
# receptive field (e.g. smoothness) will depend on the autocorrelation in the
# inputs and outputs.

# Create training and testing data, holding out the last epoch for testing
train, test = np.arange(n_epochs - 1), n_epochs - 1
X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
# ReceptiveField expects the time axis first: (n_times[, n_epochs], n_features)
# for the inputs and (n_times[, n_epochs]) for the outputs
X_train, X_test, y_train, y_test = (
    np.rollaxis(ii, -1, 0) for ii in (X_train, X_test, y_train, y_test)
)
# Model the simulated data as a function of the spectrogram input
alphas = np.logspace(-3, 3, 7)
scores = np.zeros_like(alphas)
models = []
for ii, alpha in enumerate(alphas):
    # Passing a float as ``estimator`` fits a ridge model with that
    # regularization strength
    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=alpha)
    rf.fit(X_train, y_train)

    # Now score the model on the held-out data, given input stimuli.
    scores[ii] = rf.score(X_test, y_test).item()
    models.append(rf)

times = rf.delays_ / float(rf.sfreq)

# Choose the model that performed best on the held out data
ix_best_alpha = np.argmax(scores)
best_mod = models[ix_best_alpha]
coefs = best_mod.coef_[0]  # coef_ is (n_outputs, n_features, n_delays); one output here
best_pred = best_mod.predict(X_test)[:, 0]

# Plot the original STRF, and the one that we recovered with modeling.
_, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(6, 3),
    sharey=True,
    sharex=True,
    layout="constrained",
)
ax1.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax2.pcolormesh(times, rf.feature_names, coefs, **kwargs)
ax1.set_title("Original STRF")
ax2.set_title("Best Reconstructed STRF")
plt.setp([iax.get_xticklabels() for iax in [ax1, ax2]], rotation=45)

# Plot the actual response and the predicted response on a held out stimulus
time_pred = np.arange(best_pred.shape[0]) / sfreq
fig, ax = plt.subplots()
ax.plot(time_pred, y_test, color="k", alpha=0.2, lw=4)
ax.plot(time_pred, best_pred, color="r", lw=1)
ax.set(title="Original and predicted activity", xlabel="Time (s)")
ax.legend(["Original", "Predicted"])
# %%
# Visualize the effects of regularization
# ---------------------------------------
#
# Above we fit a :class:`mne.decoding.ReceptiveField` model for each of several
# values of the ridge regularization parameter. Here we will plot the model
# score as well as the model coefficients for each value, in order to
# visualize how the coefficients change with different levels of regularization.
# These issues as well as the STRF pipeline are described in detail
# in :footcite:`TheunissenEtAl2001,WillmoreSmyth2003,HoldgrafEtAl2016`.
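
# As a quick illustrative summary (not needed for the plots below), the overall
# magnitude of the fitted coefficients shrinks as the ridge penalty grows:
for alpha, model in zip(alphas, models):
    print(f"alpha={alpha:7.0e}  ||coef|| = {np.linalg.norm(model.coef_):.4f}")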

# Plot model score for each ridge parameter
fig = plt.figure(figsize=(10, 4), layout="constrained")
ax = plt.subplot2grid([2, len(alphas)], [1, 0], 1, len(alphas))
ax.plot(np.arange(len(alphas)), scores, marker="o", color="r")
ax.annotate(
    "Best parameter",
    (ix_best_alpha, scores[ix_best_alpha]),
    (ix_best_alpha, scores[ix_best_alpha] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas])
ax.set(
    xlabel="Ridge regularization value",
    ylabel="Score ($R^2$)",
    xlim=[-0.4, len(alphas) - 0.6],
)

# Plot the STRF of each ridge parameter
for ii, (rf, i_alpha) in enumerate(zip(models, alphas)):
    ax = plt.subplot2grid([2, len(alphas)], [0, ii], 1, 1)
    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
    plt.xticks([], [])
    plt.yticks([], [])
fig.suptitle("Model coefficients / scores for many ridge parameters", y=1)

# %%
# Using different regularization types
# ------------------------------------
# In addition to the standard ridge regularization, the
# :class:`mne.decoding.TimeDelayingRidge` class also supports a
# `Laplacian <https://en.wikipedia.org/wiki/Laplacian_matrix>`_ regularization
# term of the form:
#
# .. math::
#    \left[\begin{matrix}
#         1 & -1 &   &   & & \\
#        -1 &  2 & -1 &   & & \\
#           & -1 & 2 & -1 & & \\
#           & & \ddots & \ddots & \ddots & \\
#           & & & -1 & 2 & -1 \\
#           & & &    & -1 & 1\end{matrix}\right]
#
# This imposes a smoothness constraint across nearby time samples and/or
# features. Quoting :footcite:`CrosseEtAl2016`:
#
#    Tikhonov [identity] regularization (Equation 5) reduces overfitting by
#    smoothing the TRF estimate in a way that is insensitive to
#    the amplitude of the signal of interest. However, the Laplacian
#    approach (Equation 6) reduces off-sample error whilst preserving
#    signal amplitude (Lalor et al., 2006). As a result, this approach
#    usually leads to an improved estimate of the system’s response (as
#    indexed by MSE) compared to Tikhonov regularization.
#
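
# As an illustrative sketch (for display only, not used by the fitting below),
# the matrix above is a second-difference ("Laplacian") operator; for six
# delays it can be built directly:
n_demo = 6  # hypothetical size, just for display
laplacian_demo = 2 * np.eye(n_demo) - np.eye(n_demo, k=1) - np.eye(n_demo, k=-1)
laplacian_demo[0, 0] = laplacian_demo[-1, -1] = 1  # boundary terms, as above
print(laplacian_demo)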

scores_lap = np.zeros_like(alphas)
models_lap = []
for ii, alpha in enumerate(alphas):
    estimator = TimeDelayingRidge(tmin, tmax, sfreq, reg_type="laplacian", alpha=alpha)
    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=estimator)
    rf.fit(X_train, y_train)

    # Now score the model on the held-out data, given input stimuli.
    scores_lap[ii] = rf.score(X_test, y_test).item()
    models_lap.append(rf)

ix_best_alpha_lap = np.argmax(scores_lap)

# %%
# Compare model performance
# -------------------------
# Below we visualize the model performance of each regularization method
# (ridge vs. Laplacian) for different levels of alpha. As you can see, the
# Laplacian method performs better in general, because it imposes a smoothness
# constraint along the time and feature dimensions of the coefficients.
# This matches the "true" receptive field structure and results in a better
# model fit.
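
# A quick illustrative comparison of the best score from each approach:
print(f"Best ridge     R^2 = {scores[ix_best_alpha]:0.3f}")
print(f"Best Laplacian R^2 = {scores_lap[ix_best_alpha_lap]:0.3f}")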

fig = plt.figure(figsize=(10, 6), layout="constrained")
ax = plt.subplot2grid([3, len(alphas)], [2, 0], 1, len(alphas))
ax.plot(np.arange(len(alphas)), scores_lap, marker="o", color="r")
ax.plot(np.arange(len(alphas)), scores, marker="o", color="0.5", ls=":")
ax.annotate(
    "Best Laplacian",
    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap]),
    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
ax.annotate(
    "Best Ridge",
    (ix_best_alpha, scores[ix_best_alpha]),
    (ix_best_alpha, scores[ix_best_alpha] - 0.1),
    arrowprops={"arrowstyle": "->"},
)
plt.xticks(np.arange(len(alphas)), [f"{ii:.0e}" for ii in alphas])
ax.set(
    xlabel="Laplacian regularization value",
    ylabel="Score ($R^2$)",
    xlim=[-0.4, len(alphas) - 0.6],
)

# Plot the STRFs for each regularization value (Laplacian on top, ridge below)
xlim = times[[0, -1]]
for ii, (rf_lap, rf, i_alpha) in enumerate(zip(models_lap, models, alphas)):
    ax = plt.subplot2grid([3, len(alphas)], [0, ii], 1, 1)
    ax.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)
    ax.set(xticks=[], yticks=[], xlim=xlim)
    if ii == 0:
        ax.set(ylabel="Laplacian")
    ax = plt.subplot2grid([3, len(alphas)], [1, ii], 1, 1)
    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
    ax.set(xticks=[], yticks=[], xlim=xlim)
    if ii == 0:
        ax.set(ylabel="Ridge")
fig.suptitle("Model coefficients / scores for Laplacian regularization", y=1)

# %%
# Plot the original STRF alongside the best ridge and best Laplacian estimates.
rf = models[ix_best_alpha]
rf_lap = models_lap[ix_best_alpha_lap]
_, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    figsize=(9, 3),
    sharey=True,
    sharex=True,
    layout="constrained",
)
ax1.pcolormesh(delays_sec, freqs, weights, **kwargs)
ax2.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)
ax3.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)
ax1.set_title("Original STRF")
ax2.set_title("Best Ridge STRF")
ax3.set_title("Best Laplacian STRF")
plt.setp([iax.get_xticklabels() for iax in [ax1, ax2, ax3]], rotation=45)

# %%
# References
# ==========
# .. footbibliography::