[074d3d]: / mne / viz / topo.py

Download this file

1376 lines (1242 with data), 42.6 kB

   1
   2
   3
   4
   5
   6
   7
   8
   9
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
"""Functions to plot M/EEG data on topo (one axes per channel)."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
from functools import partial
import numpy as np
from scipy import ndimage
from .._fiff.pick import _picks_to_idx, channel_type, pick_types
from ..defaults import _handle_default
from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc
from .ui_events import ChannelsSelect, publish, subscribe
from .utils import (
DraggableColorbar,
SelectFromCollection,
_check_cov,
_check_delayed_ssp,
_draw_proj_checkbox,
_plot_masked_image,
_setup_ax_spines,
_setup_vmin_vmax,
add_background_image,
plt_show,
)
@fill_doc
def iter_topography(
info,
layout=None,
on_pick=None,
fig=None,
fig_facecolor="k",
axis_facecolor="k",
axis_spinecolor="k",
layout_scale=None,
legend=False,
select=False,
):
"""Create iterator over channel positions.
This function returns a generator that unpacks into
a series of matplotlib axis objects and data / channel
indices, both corresponding to the sensor positions
of the related layout passed or inferred from the channel info.
Hence, this enables convenient topography plot customization.
Parameters
----------
%(info_not_none)s
layout : instance of mne.channels.Layout | None
The layout to use. If None, layout will be guessed.
on_pick : callable | None
The callback function to be invoked on clicking one
of the axes. Is supposed to instantiate the following
API: ``function(axis, channel_index)``.
fig : matplotlib.figure.Figure | None
The figure object to be considered. If None, a new
figure will be created.
fig_facecolor : color
The figure face color. Defaults to black.
axis_facecolor : color
The axis face color. Defaults to black.
axis_spinecolor : color
The axis spine color. Defaults to black. In other words,
the color of the axis' edge lines.
layout_scale : float | None
Scaling factor for adjusting the relative size of the layout
on the canvas. If None, nothing will be scaled.
legend : bool
If True, an additional axis is created in the bottom right corner
that can be used to, e.g., construct a legend. The index of this
axis will be -1.
select : bool
Whether to enable the lasso-selection tool to enable the user to select
channels. The selected channels will be available in
``fig.lasso.selection``.
.. versionadded:: 1.10.0
Returns
-------
gen : generator
A generator that can be unpacked into:
ax : matplotlib.axis.Axis
The current axis of the topo plot.
ch_dx : int
The related channel index.
"""
return _iter_topography(
info,
layout,
on_pick,
fig,
fig_facecolor,
axis_facecolor,
axis_spinecolor,
layout_scale,
legend=legend,
select=select,
)
def _legend_axis(pos):
"""Add a legend axis to the bottom right."""
import matplotlib.pyplot as plt
left, bottom = pos[:, 0].max(), pos[:, 1].min()
# check if legend axis overlaps a data axis
overlaps = False
for _pos in pos:
h_overlap = _pos[0] <= left <= (_pos[0] + _pos[2])
v_overlap = _pos[1] <= bottom <= (_pos[1] + _pos[3])
if h_overlap and v_overlap:
overlaps = True
break
if overlaps:
left += 1.2 * _pos[2]
wid, hei = pos[-1, 2:]
return plt.axes([left, bottom, wid, hei])
def _iter_topography(
info,
layout,
on_pick,
fig,
fig_facecolor="k",
axis_facecolor="k",
axis_spinecolor="k",
layout_scale=None,
unified=False,
img=False,
axes=None,
legend=False,
select=False,
):
"""Iterate over topography.
Has the same parameters as iter_topography, plus:
unified : bool
If False (default), multiple matplotlib axes will be used.
If True, a single axis will be constructed. The former is
useful for custom plotting, the latter for speed.
"""
from matplotlib import collections
from matplotlib import pyplot as plt
from ..channels.layout import find_layout
if fig is None:
# Don't use constrained layout because we place axes manually
fig = plt.figure(layout=None)
def format_coord_unified(x, y, pos=None, ch_names=None):
"""Update status bar with channel name under cursor."""
# find candidate channels (ones that are down and left from cursor)
pdist = np.array([x, y]) - pos[:, :2]
pind = np.where((pdist >= 0).all(axis=1))[0]
if len(pind) > 0:
# find the closest channel
closest = pind[np.sum(pdist[pind, :] ** 2, axis=1).argmin()]
# check whether we are inside its box
in_box = (pdist[closest, :] < pos[closest, 2:]).all()
else:
in_box = False
return (
f"{ch_names[closest]} (click to magnify)" if in_box else "No channel here"
)
def format_coord_multiaxis(x, y, ch_name=None):
"""Update status bar with channel name under cursor."""
return f"{ch_name} (click to magnify)"
fig.set_facecolor(fig_facecolor)
if layout is None:
layout = find_layout(info)
if on_pick is not None:
callback = partial(_plot_topo_onpick, show_func=on_pick)
fig.canvas.mpl_connect("button_press_event", callback)
pos = layout.pos.copy()
if layout_scale:
pos[:, :2] *= layout_scale
ch_names = _clean_names(info["ch_names"])
iter_ch = [(x, y) for x, y in enumerate(layout.names) if y in ch_names]
if unified:
if axes is None:
under_ax = plt.axes([0, 0, 1, 1])
under_ax.axis("off")
else:
under_ax = axes
under_ax.format_coord = partial(
format_coord_unified, pos=pos, ch_names=layout.names
)
under_ax.set(xlim=[0, 1], ylim=[0, 1])
axs = list()
shown_ch_names = []
for idx, name in iter_ch:
ch_idx = ch_names.index(name)
shown_ch_names.append(name)
if not unified: # old, slow way
ax = plt.axes(pos[idx])
ax.patch.set_facecolor(axis_facecolor)
for spine in ax.spines.values():
spine.set_color(axis_spinecolor)
if not legend:
ax.set(xticklabels=[], yticklabels=[])
for tick in ax.get_xticklines() + ax.get_yticklines():
tick.set_visible(False)
ax._mne_ch_name = name
ax._mne_ch_idx = ch_idx
ax._mne_ax_face_color = axis_facecolor
ax.format_coord = partial(format_coord_multiaxis, ch_name=name)
yield ax, ch_idx
else:
ax = Bunch(
ax=under_ax,
pos=pos[idx],
data_lines=list(),
_mne_ch_name=name,
_mne_ch_idx=ch_idx,
_mne_ax_face_color=axis_facecolor,
)
axs.append(ax)
if not unified and legend:
ax = _legend_axis(pos)
yield ax, -1
if unified:
under_ax._mne_axs = axs
# Create a PolyCollection for the axis backgrounds
sel_pos = pos[[i[0] for i in iter_ch]]
verts = np.transpose(
[
sel_pos[:, :2],
sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0],
sel_pos[:, :2] + sel_pos[:, 2:],
sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1],
],
[1, 0, 2],
)
if not img: # Not needed for image plots.
collection = collections.PolyCollection(
verts,
facecolor=axis_facecolor,
edgecolor=axis_spinecolor,
linewidth=1.0,
)
under_ax.add_collection(collection)
if select:
# Configure the lasso-selection tool
fig.lasso = SelectFromCollection(
ax=under_ax,
collection=collection,
names=shown_ch_names,
alpha_nonselected=0,
alpha_selected=1,
linewidth_nonselected=0,
linewidth_selected=0.7,
)
def on_select():
publish(fig, ChannelsSelect(ch_names=fig.lasso.selection))
def on_channels_select(event):
selection_inds = np.flatnonzero(
np.isin(shown_ch_names, event.ch_names)
)
fig.lasso.select_many(selection_inds)
fig.lasso.callbacks.append(on_select)
subscribe(fig, "channels_select", on_channels_select)
for ax in axs:
yield ax, ax._mne_ch_idx
def _plot_topo(
info,
times,
show_func,
click_func=None,
layout=None,
vmin=None,
vmax=None,
ylim=None,
colorbar=None,
border="none",
axis_facecolor="k",
fig_facecolor="k",
cmap="RdBu_r",
layout_scale=None,
title=None,
x_label=None,
y_label=None,
font_color="w",
unified=False,
img=False,
axes=None,
select=False,
):
"""Plot on sensor layout."""
import matplotlib.pyplot as plt
if layout.kind == "custom":
layout = deepcopy(layout)
layout.pos[:, :2] -= layout.pos[:, :2].min(0)
layout.pos[:, :2] /= layout.pos[:, :2].max(0)
# prepare callbacks
tmin, tmax = times[0], times[-1]
click_func = show_func if click_func is None else click_func
on_pick = partial(
click_func,
tmin=tmin,
tmax=tmax,
vmin=vmin,
vmax=vmax,
ylim=ylim,
x_label=x_label,
y_label=y_label,
)
if axes is None:
# Don't use constrained layout because we place axes manually
fig = plt.figure(layout=None)
axes = plt.axes([0.015, 0.025, 0.97, 0.95])
axes.set_facecolor(fig_facecolor)
else:
fig = axes.figure
if colorbar:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin, vmax))
sm.set_array(np.linspace(vmin, vmax))
cb = fig.colorbar(
sm, ax=axes, pad=0.025, fraction=0.075, shrink=0.5, anchor=(-1, 0.5)
)
cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
plt.setp(cb_yticks, color=font_color)
axes.axis("off")
my_topo_plot = _iter_topography(
info,
layout=layout,
on_pick=on_pick,
fig=fig,
layout_scale=layout_scale,
axis_spinecolor=border,
axis_facecolor=axis_facecolor,
fig_facecolor=fig_facecolor,
unified=unified,
img=img,
axes=axes,
select=select,
)
for ax, ch_idx in my_topo_plot:
if layout.kind == "Vectorview-all" and ylim is not None:
ylim_ = ylim.get(channel_type(info, ch_idx))
else:
ylim_ = ylim
show_func(ax, ch_idx, tmin=tmin, tmax=tmax, vmin=vmin, vmax=vmax, ylim=ylim_)
if title is not None:
plt.figtext(0.03, 0.95, title, color=font_color, fontsize=15, va="top")
return fig
def _plot_topo_onpick(event, show_func):
"""Onpick callback that shows a single channel in a new figure."""
orig_ax = event.inaxes
fig = orig_ax.figure
# If we are doing lasso select, allow it to handle the click instead.
if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]:
return
# make sure that the swipe gesture in OS-X doesn't open many figures
if fig.canvas._key in ["shift", "alt"]:
return
import matplotlib.pyplot as plt
try:
if hasattr(orig_ax, "_mne_axs"): # in unified, single-axes mode
x, y = event.xdata, event.ydata
for ax in orig_ax._mne_axs:
if (
x >= ax.pos[0]
and y >= ax.pos[1]
and x <= ax.pos[0] + ax.pos[2]
and y <= ax.pos[1] + ax.pos[3]
):
orig_ax = ax
break
else:
# no axis found
return
elif not hasattr(orig_ax, "_mne_ch_idx"):
# neither old nor new mode
return
ch_idx = orig_ax._mne_ch_idx
face_color = orig_ax._mne_ax_face_color
fig, ax = plt.subplots(1)
plt.title(orig_ax._mne_ch_name)
ax.set_facecolor(face_color)
# allow custom function to override parameters
show_func(ax, ch_idx)
plt_show(fig=fig)
except Exception as err:
# matplotlib silently ignores exceptions in event handlers,
# so we print
# it here to know what went wrong
print(err)
raise
def _compute_ax_scalings(bn, xlim, ylim):
"""Compute scale factors for a unified plot."""
if isinstance(ylim, dict):
# Take the first (ymin, ymax) entry.
ylim = next(iter(ylim.values()))
pos = bn.pos
bn.x_s = pos[2] / (xlim[1] - xlim[0])
bn.x_t = pos[0] - bn.x_s * xlim[0]
bn.y_s = pos[3] / (ylim[1] - ylim[0])
bn.y_t = pos[1] - bn.y_s * ylim[0]
def _imshow_tfr(
ax,
ch_idx,
tmin,
tmax,
vmin,
vmax,
onselect,
*,
ylim=None,
tfr=None,
freq=None,
x_label=None,
y_label=None,
colorbar=False,
cmap=("RdBu_r", True),
yscale="auto",
mask=None,
mask_style="both",
mask_cmap="Greys",
mask_alpha=0.1,
cnorm=None,
):
"""Show time-frequency map as two-dimensional image."""
from matplotlib.widgets import RectangleSelector
_check_option("yscale", yscale, ["auto", "linear", "log"])
cmap, interactive_cmap = cmap
times = np.linspace(tmin, tmax, num=tfr[ch_idx].shape[1])
img, t_end = _plot_masked_image(
ax,
tfr[ch_idx],
times,
mask,
yvals=freq,
cmap=cmap,
vmin=vmin,
vmax=vmax,
mask_style=mask_style,
mask_alpha=mask_alpha,
mask_cmap=mask_cmap,
yscale=yscale,
cnorm=cnorm,
)
if x_label is not None:
ax.set_xlabel(x_label)
if y_label is not None:
ax.set_ylabel(y_label)
if colorbar:
if isinstance(colorbar, DraggableColorbar):
cbar = colorbar.cbar # this happens with multiaxes case
else:
cbar = ax.get_figure().colorbar(mappable=img, ax=ax)
if interactive_cmap:
ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None)
ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept
return t_end
def _imshow_tfr_unified(
bn,
ch_idx,
tmin,
tmax,
vmin,
vmax,
onselect,
*,
ylim=None,
tfr=None,
freq=None,
vline=None,
x_label=None,
y_label=None,
colorbar=False,
picker=True,
cmap="RdBu_r",
title=None,
hline=None,
):
"""Show multiple tfrs on topo using a single axes."""
_compute_ax_scalings(bn, (tmin, tmax), (freq[0], freq[-1]))
ax = bn.ax
data_lines = bn.data_lines
extent = (
bn.x_t + bn.x_s * tmin,
bn.x_t + bn.x_s * tmax,
bn.y_t + bn.y_s * freq[0],
bn.y_t + bn.y_s * freq[-1],
)
data_lines.append(
ax.imshow(
tfr[ch_idx],
extent=extent,
aspect="auto",
origin="lower",
vmin=vmin,
vmax=vmax,
cmap=cmap,
)
)
data_lines[-1].set_clip_box(_pos_to_bbox(bn.pos, ax))
def _plot_timeseries(
ax,
ch_idx,
tmin,
tmax,
vmin,
vmax,
ylim,
data,
color,
times,
vline=None,
x_label=None,
y_label=None,
colorbar=False,
hline=None,
hvline_color="w",
labels=None,
):
"""Show time series on topo split across multiple axes."""
import matplotlib.pyplot as plt
picker_flag = False
for data_, color_, times_ in zip(data, color, times):
if not picker_flag:
# use large tol for picker so we can click anywhere in the axes
line = ax.plot(times_, data_[ch_idx], color=color_, picker=True)[0]
line.set_pickradius(1e9)
picker_flag = True
else:
ax.plot(times_, data_[ch_idx], color=color_)
def _format_coord(x, y, labels, ax):
"""Create status string based on cursor coordinates."""
# find indices for datasets near cursor (if any)
tdiffs = [np.abs(tvec - x).min() for tvec in times]
nearby = [k for k, tdiff in enumerate(tdiffs) if tdiff < (tmax - tmin) / 100]
xlabel = ax.get_xlabel()
xunit = (
xlabel[xlabel.find("(") + 1 : xlabel.find(")")]
if "(" in xlabel and ")" in xlabel
else "s"
)
timestr = f"{x:6.3f} {xunit}: "
if not nearby:
return f"{timestr} Nothing here"
labels = [""] * len(nearby) if labels is None else labels
nearby_data = [(data[n], labels[n], times[n]) for n in nearby]
ylabel = ax.get_ylabel()
yunit = (
ylabel[ylabel.find("(") + 1 : ylabel.find(")")]
if "(" in ylabel and ")" in ylabel
else ""
)
# try to estimate whether to truncate condition labels
slen = 9 + len(xunit) + sum([12 + len(yunit) + len(label) for label in labels])
bar_width = (ax.figure.get_size_inches() * ax.figure.dpi)[0] / 5.5
# show labels and y values for datasets near cursor
trunc_labels = bar_width < slen
s = timestr
for data_, label, tvec in nearby_data:
idx = np.abs(tvec - x).argmin()
s += f"{data_[ch_idx, idx]:7.2f} {yunit}"
if trunc_labels:
label = label if len(label) <= 10 else f"{label[:6]}..{label[-2:]}"
s += f" [{label}] " if label else " "
return s
ax.format_coord = lambda x, y: _format_coord(x, y, labels=labels, ax=ax)
def _cursor_vline(event):
"""Draw cursor (vertical line)."""
ax = event.inaxes
if not ax:
return
if ax._cursorline is not None:
ax._cursorline.remove()
ax._cursorline = ax.axvline(event.xdata, color=ax._cursorcolor)
ax.figure.canvas.draw()
def _rm_cursor(event):
ax = event.inaxes
if ax._cursorline is not None:
ax._cursorline.remove()
ax._cursorline = None
ax.figure.canvas.draw()
ax._cursorline = None
# choose cursor color based on perceived brightness of background
facecol = _to_rgb(ax.get_facecolor())
face_brightness = np.dot(facecol, [299, 587, 114])
ax._cursorcolor = "white" if face_brightness < 150 else "black"
plt.connect("motion_notify_event", _cursor_vline)
plt.connect("axes_leave_event", _rm_cursor)
ymin, ymax = ax.get_ylim()
# don't pass vline or hline here (this fxn doesn't do hvline_color):
_setup_ax_spines(ax, [], tmin, tmax, ymin, ymax, hline=False)
ax.figure.set_facecolor("k" if hvline_color == "w" else "w")
ax.spines["bottom"].set_color(hvline_color)
ax.spines["left"].set_color(hvline_color)
ax.tick_params(axis="x", colors=hvline_color, which="both")
ax.tick_params(axis="y", colors=hvline_color, which="both")
ax.title.set_color(hvline_color)
ax.xaxis.label.set_color(hvline_color)
ax.yaxis.label.set_color(hvline_color)
if x_label is not None:
ax.set_xlabel(x_label)
if y_label is not None:
if isinstance(y_label, list):
ax.set_ylabel(y_label[ch_idx])
else:
ax.set_ylabel(y_label)
if vline is not None:
vline = [vline] if _is_numeric(vline) else vline
for vline_ in vline:
plt.axvline(vline_, color=hvline_color, linewidth=1.0, linestyle="--")
if hline is not None:
hline = [hline] if _is_numeric(hline) else hline
for hline_ in hline:
plt.axhline(hline_, color=hvline_color, linewidth=1.0, zorder=10)
if colorbar:
plt.colorbar()
def _plot_timeseries_unified(
bn,
ch_idx,
tmin,
tmax,
vmin,
vmax,
ylim,
data,
color,
times,
vline=None,
x_label=None,
y_label=None,
colorbar=False,
hline=None,
hvline_color="w",
):
"""Show multiple time series on topo using a single axes."""
import matplotlib.pyplot as plt
if not (ylim and not any(v is None for v in ylim)):
ylim = [min(np.min(d) for d in data), max(np.max(d) for d in data)]
# Translation and scale parameters to take data->under_ax normalized coords
_compute_ax_scalings(bn, (tmin, tmax), ylim)
pos = bn.pos
data_lines = bn.data_lines
ax = bn.ax
for data_, color_, times_ in zip(data, color, times):
data_lines.append(
ax.plot(
bn.x_t + bn.x_s * times_,
bn.y_t + bn.y_s * data_[ch_idx],
linewidth=0.5,
color=color_,
)[0]
)
# Needs to be done afterward for some reason (probable matlotlib bug)
data_lines[-1].set_clip_box(_pos_to_bbox(pos, ax))
if vline:
vline = np.array(vline) * bn.x_s + bn.x_t
ax.vlines(
vline,
pos[1],
pos[1] + pos[3],
color=hvline_color,
linewidth=0.5,
linestyle="--",
)
if hline:
hline = np.array(hline) * bn.y_s + bn.y_t
ax.hlines(hline, pos[0], pos[0] + pos[2], color=hvline_color, linewidth=0.5)
if x_label is not None:
ax.text(
pos[0] + pos[2] / 2.0,
pos[1],
x_label,
horizontalalignment="center",
verticalalignment="top",
)
if y_label is not None:
y_label = y_label[ch_idx] if isinstance(y_label, list) else y_label
ax.text(
pos[0],
pos[1] + pos[3] / 2.0,
y_label,
horizontalignment="right",
verticalalignment="middle",
rotation=90,
)
if colorbar:
plt.colorbar()
def _erfimage_imshow(
ax,
ch_idx,
tmin,
tmax,
vmin,
vmax,
ylim=None,
data=None,
epochs=None,
sigma=None,
order=None,
scalings=None,
vline=None,
x_label=None,
y_label=None,
colorbar=False,
cmap="RdBu_r",
vlim_array=None,
):
"""Plot erfimage on sensor topography."""
import matplotlib.pyplot as plt
this_data = data[:, ch_idx, :]
if vlim_array is not None:
vmin, vmax = vlim_array[ch_idx]
if callable(order):
order = order(epochs.times, this_data)
if order is not None:
this_data = this_data[order]
if sigma > 0.0:
this_data = ndimage.gaussian_filter1d(this_data, sigma=sigma, axis=0)
img = ax.imshow(
this_data,
extent=[tmin, tmax, 0, len(data)],
aspect="auto",
origin="lower",
vmin=vmin,
vmax=vmax,
picker=True,
cmap=cmap,
interpolation="nearest",
)
ax = plt.gca()
if x_label is not None:
ax.set_xlabel(x_label)
if y_label is not None:
ax.set_ylabel(y_label)
if colorbar:
plt.colorbar(mappable=img)
def _erfimage_imshow_unified(
bn,
ch_idx,
tmin,
tmax,
vmin,
vmax,
ylim=None,
data=None,
epochs=None,
sigma=None,
order=None,
scalings=None,
vline=None,
x_label=None,
y_label=None,
colorbar=False,
cmap="RdBu_r",
vlim_array=None,
):
"""Plot erfimage topography using a single axis."""
_compute_ax_scalings(bn, (tmin, tmax), (0, len(epochs.events)))
ax = bn.ax
data_lines = bn.data_lines
extent = (
bn.x_t + bn.x_s * tmin,
bn.x_t + bn.x_s * tmax,
bn.y_t,
bn.y_t + bn.y_s * len(epochs.events),
)
this_data = data[:, ch_idx, :]
vmin, vmax = (None, None) if vlim_array is None else vlim_array[ch_idx]
if callable(order):
order = order(epochs.times, this_data)
if order is not None:
this_data = this_data[order]
if sigma > 0.0:
this_data = ndimage.gaussian_filter1d(this_data, sigma=sigma, axis=0)
data_lines.append(
ax.imshow(
this_data,
extent=extent,
aspect="auto",
origin="lower",
vmin=vmin,
vmax=vmax,
picker=True,
cmap=cmap,
interpolation="nearest",
)
)
def _plot_evoked_topo(
evoked,
layout=None,
layout_scale=0.945,
color=None,
border="none",
ylim=None,
scalings=None,
title=None,
proj=False,
vline=(0.0,),
hline=(0.0,),
fig_facecolor="k",
fig_background=None,
axis_facecolor="k",
font_color="w",
merge_channels=False,
legend=True,
axes=None,
noise_cov=None,
exclude="bads",
select=False,
show=True,
):
"""Plot 2D topography of evoked responses.
Clicking on the plot of an individual sensor opens a new figure showing
the evoked response for the selected sensor.
Parameters
----------
evoked : list of Evoked | Evoked
The evoked response to plot.
layout : instance of Layout | None
Layout instance specifying sensor positions (does not need to
be specified for Neuromag data). If possible, the correct layout is
inferred from the data.
layout_scale : float
Scaling factor for adjusting the relative size of the layout
on the canvas.
color : list of color objects | color object | None
Everything matplotlib accepts to specify colors. If not list-like,
the color specified will be repeated. If None, colors are
automatically drawn.
border : str
Matplotlib borders style to be used for each sensor plot.
ylim : dict | None
ylim for plots (after scaling has been applied). The value
determines the upper and lower subplot limits. e.g.
ylim = dict(eeg=[-20, 20]). Valid keys are eeg, mag, grad. If None,
the ylim parameter for each channel type is determined by the minimum
and maximum peak.
scalings : dict | None
The scalings of the channel types to be applied for plotting. If None,`
defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``.
title : str
Title of the figure.
proj : bool | 'interactive'
If true SSP projections are applied before display. If 'interactive',
a check box for reversible selection of SSP projection vectors will
be shown.
vline : list of floats | None
The values at which to show a vertical line.
hline : list of floats | None
The values at which to show a horizontal line.
fig_facecolor : color
The figure face color. Defaults to black.
fig_background : None | array
A background image for the figure. This must be a valid input to
`matplotlib.pyplot.imshow`. Defaults to None.
axis_facecolor : color
The face color to be used for each sensor plot. Defaults to black.
font_color : color
The color of text in the colorbar and title. Defaults to white.
merge_channels : bool
Whether to use RMS value of gradiometer pairs. Only works for Neuromag
data. Defaults to False.
legend : bool | int | string | tuple
If True, create a legend based on evoked.comment. If False, disable the
legend. Otherwise, the legend is created and the parameter value is
passed as the location parameter to the matplotlib legend call. It can
be an integer (e.g. 0 corresponds to upper right corner of the plot),
a string (e.g. 'upper right'), or a tuple (x, y coordinates of the
lower left corner of the legend in the axes coordinate system).
See matplotlib documentation for more details.
axes : instance of matplotlib Axes | None
Axes to plot into. If None, axes will be created.
noise_cov : instance of Covariance | str | None
Noise covariance used to whiten the data while plotting.
Whitened data channels names are shown in italic.
Can be a string to load a covariance from disk.
exclude : list of str | 'bads'
Channels names to exclude from being shown. If 'bads', the
bad channels are excluded. By default, exclude is set to 'bads'.
select : bool
Whether to enable the lasso-selection tool to enable the user to select
channels. The selected channels will be available in
``fig.lasso.selection``.
show : bool
Show figure if True.
.. versionadded:: 0.16.0
Returns
-------
fig : instance of matplotlib.figure.Figure
Images of evoked responses at sensor locations
"""
import matplotlib.pyplot as plt
from ..channels.layout import _merge_ch_data, _pair_grad_sensors, find_layout
from ..cov import whiten_evoked
if type(evoked) not in (tuple, list):
evoked = [evoked]
noise_cov = _check_cov(noise_cov, evoked[0].info)
if noise_cov is not None:
evoked = [whiten_evoked(e, noise_cov) for e in evoked]
else:
evoked = [e.copy() for e in evoked]
info = evoked[0].info
ch_names = evoked[0].ch_names
scalings = _handle_default("scalings", scalings)
if not all(e.ch_names == ch_names for e in evoked):
raise ValueError("All evoked.picks must be the same")
ch_names = _clean_names(ch_names)
if merge_channels:
picks = _pair_grad_sensors(info, topomap_coords=False, exclude=exclude)
chs = list()
for pick in picks[::2]:
ch = info["chs"][pick]
ch["ch_name"] = ch["ch_name"][:-1] + "X"
chs.append(ch)
with info._unlock(update_redundant=True, check_after=True):
info["chs"] = chs
info["bads"] = list() # Bads handled by pair_grad_sensors
new_picks = list()
for e in evoked:
data, _ = _merge_ch_data(e.data[picks], "grad", [])
if noise_cov is None:
data *= scalings["grad"]
e.data = data
new_picks.append(range(len(data)))
picks = new_picks
types_used = ["grad"]
unit = _handle_default("units")["grad"] if noise_cov is None else "NA"
y_label = f"RMS amplitude ({unit})"
if layout is None:
layout = find_layout(info, exclude=exclude)
else:
layout = layout.pick(
"all",
exclude=_picks_to_idx(
info,
exclude if exclude != "bads" else info["bads"],
exclude=(),
allow_empty=True,
),
)
if not merge_channels:
# XXX. at the moment we are committed to 1- / 2-sensor-types layouts
chs_in_layout = [ch_name for ch_name in ch_names if ch_name in layout.names]
types_used = [channel_type(info, ch_names.index(ch)) for ch in chs_in_layout]
# Using dict conversion to remove duplicates
types_used = list(dict.fromkeys(types_used))
# remove possible reference meg channels
types_used = [
types_used for types_used in types_used if types_used != "ref_meg"
]
# one check for all vendors
is_meg = len([x for x in types_used if x in ["mag", "grad"]]) > 0
is_nirs = (
len(
[
x
for x in types_used
if x in ("hbo", "hbr", "fnirs_cw_amplitude", "fnirs_od")
]
)
> 0
)
if is_meg:
picks = [
pick_types(info, meg=kk, ref_meg=False, exclude=exclude)
for kk in types_used
]
elif is_nirs:
picks = [
pick_types(info, fnirs=kk, ref_meg=False, exclude=exclude)
for kk in types_used
]
else:
types_used_kwargs = {t: True for t in types_used}
picks = [pick_types(info, meg=False, exclude=exclude, **types_used_kwargs)]
assert isinstance(picks, list) and len(types_used) == len(picks)
if noise_cov is None:
for e in evoked:
for pick, ch_type in zip(picks, types_used):
e.data[pick] *= scalings[ch_type]
if proj is True and all(e.proj is not True for e in evoked):
evoked = [e.apply_proj() for e in evoked]
elif proj == "interactive": # let it fail early.
for e in evoked:
_check_delayed_ssp(e)
# Y labels for picked plots must be reconstructed
y_label = list()
for ch_idx in range(len(chs_in_layout)):
if noise_cov is None:
unit = _handle_default("units")[channel_type(info, ch_idx)]
else:
unit = "NA"
y_label.append(f"Amplitude ({unit})")
if ylim is None:
# find minima and maxima over all evoked data for each channel pick
ylim_ = dict()
for ch_type, p in zip(types_used, picks):
ylim_[ch_type] = [
min([e.data[p].min() for e in evoked]),
max([e.data[p].max() for e in evoked]),
]
elif isinstance(ylim, dict):
ylim_ = _handle_default("ylim", ylim)
ylim_ = {kk: ylim_[kk] for kk in types_used}
else:
raise TypeError(f"ylim must be None or a dict. Got {type(ylim)}.")
data = [e.data for e in evoked]
comments = [e.comment for e in evoked]
times = [e.times for e in evoked]
show_func = partial(
_plot_timeseries_unified,
data=data,
color=color,
times=times,
vline=vline,
hline=hline,
hvline_color=font_color,
)
click_func = partial(
_plot_timeseries,
data=data,
color=color,
times=times,
vline=vline,
hline=hline,
hvline_color=font_color,
labels=comments,
)
time_min = min([t[0] for t in times])
time_max = max([t[-1] for t in times])
fig = _plot_topo(
info=info,
times=[time_min, time_max],
show_func=show_func,
click_func=click_func,
layout=layout,
colorbar=False,
ylim=ylim_,
cmap=None,
layout_scale=layout_scale,
border=border,
fig_facecolor=fig_facecolor,
font_color=font_color,
axis_facecolor=axis_facecolor,
title=title,
x_label="Time (s)",
y_label=y_label,
unified=True,
axes=axes,
select=select,
)
add_background_image(fig, fig_background)
if legend is not False:
legend_loc = 0 if legend is True else legend
labels = [e.comment if e.comment else "Unknown" for e in evoked]
if select:
handles = fig.axes[0].lines[1 : len(evoked) + 1]
else:
handles = fig.axes[0].lines[: len(evoked)]
legend = plt.legend(
labels=labels, handles=handles, loc=legend_loc, prop={"size": 10}
)
legend.get_frame().set_facecolor(axis_facecolor)
txts = legend.get_texts()
for txt, col in zip(txts, color):
txt.set_color(col)
if proj == "interactive":
for e in evoked:
_check_delayed_ssp(e)
params = dict(
evokeds=evoked,
times=times,
plot_update_proj_callback=_plot_update_evoked_topo_proj,
projs=evoked[0].info["projs"],
fig=fig,
)
_draw_proj_checkbox(None, params)
plt_show(show)
return fig
def _plot_update_evoked_topo_proj(params, bools):
"""Update topo sensor plots."""
evokeds = [e.copy() for e in params["evokeds"]]
fig = params["fig"]
projs = [proj for proj, b in zip(params["projs"], bools) if b]
params["proj_bools"] = bools
for e in evokeds:
e.add_proj(projs, remove_existing=True)
e.apply_proj()
# make sure to only modify the time courses, not the ticks
for ax in fig.axes[0]._mne_axs:
for line, evoked in zip(ax.data_lines, evokeds):
line.set_ydata(ax.y_t + ax.y_s * evoked.data[ax._mne_ch_idx])
fig.canvas.draw()
def plot_topo_image_epochs(
epochs,
layout=None,
sigma=0.0,
vmin=None,
vmax=None,
colorbar=None,
order=None,
cmap="RdBu_r",
layout_scale=0.95,
title=None,
scalings=None,
border="none",
fig_facecolor="k",
fig_background=None,
font_color="w",
select=False,
show=True,
):
"""Plot Event Related Potential / Fields image on topographies.
Parameters
----------
epochs : instance of :class:`~mne.Epochs`
The epochs.
layout : instance of Layout
System specific sensor positions.
sigma : float
The standard deviation of the Gaussian smoothing to apply along
the epoch axis to apply in the image. If 0., no smoothing is applied.
vmin : float
The min value in the image. The unit is µV for EEG channels,
fT for magnetometers and fT/cm for gradiometers.
vmax : float
The max value in the image. The unit is µV for EEG channels,
fT for magnetometers and fT/cm for gradiometers.
colorbar : bool | None
Whether to display a colorbar or not. If ``None`` a colorbar will be
shown only if all channels are of the same type. Defaults to ``None``.
order : None | array of int | callable
If not None, order is used to reorder the epochs on the y-axis
of the image. If it's an array of int it should be of length
the number of good epochs. If it's a callable the arguments
passed are the times vector and the data as 2d array
(data.shape[1] == len(times)).
cmap : colormap
Colors to be mapped to the values.
layout_scale : float
Scaling factor for adjusting the relative size of the layout
on the canvas.
title : str
Title of the figure.
scalings : dict | None
The scalings of the channel types to be applied for plotting. If
``None``, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``.
border : str
Matplotlib borders style to be used for each sensor plot.
fig_facecolor : color
The figure face color. Defaults to black.
fig_background : None | array
A background image for the figure. This must be a valid input to
:func:`matplotlib.pyplot.imshow`. Defaults to ``None``.
font_color : color
The color of tick labels in the colorbar. Defaults to white.
select : bool
Whether to enable the lasso-selection tool to enable the user to select
channels. The selected channels will be available in
``fig.lasso.selection``.
.. versionadded:: 1.10.0
show : bool
Whether to show the figure. Defaults to ``True``.
Returns
-------
fig : instance of :class:`matplotlib.figure.Figure`
Figure distributing one image per channel across sensor topography.
Notes
-----
In an interactive Python session, this plot will be interactive; clicking
on a channel image will pop open a larger view of the image; this image
will always have a colorbar even when the topo plot does not (because it
shows multiple sensor types).
"""
from ..channels.layout import find_layout
scalings = _handle_default("scalings", scalings)
# make a copy because we discard non-data channels and scale the data
epochs = epochs.copy().load_data()
# use layout to subset channels present in epochs object
if layout is None:
layout = find_layout(epochs.info)
ch_names = set(layout.names) & set(epochs.ch_names)
idxs = [epochs.ch_names.index(ch_name) for ch_name in ch_names]
epochs = epochs.pick(idxs)
# get lists of channel type & scale coefficient
ch_types = epochs.get_channel_types()
scale_coeffs = [scalings.get(ch_type, 1) for ch_type in ch_types]
# scale the data
epochs._data *= np.array(scale_coeffs)[:, np.newaxis]
data = epochs.get_data(copy=False)
# get vlims for each channel type
vlim_dict = dict()
for ch_type in set(ch_types):
this_data = data[:, np.where(np.array(ch_types) == ch_type)]
vlim_dict[ch_type] = _setup_vmin_vmax(this_data, vmin, vmax)
vlim_array = np.array([vlim_dict[ch_type] for ch_type in ch_types])
# only show colorbar if we have a single channel type
if colorbar is None:
colorbar = len(set(ch_types)) == 1
# if colorbar=True, we know we have only 1 channel type so all entries
# in vlim_array are the same, just take the first one
if colorbar and vmin is None and vmax is None:
vmin, vmax = vlim_array[0]
show_func = partial(
_erfimage_imshow_unified,
scalings=scale_coeffs,
order=order,
data=data,
epochs=epochs,
sigma=sigma,
cmap=cmap,
vlim_array=vlim_array,
)
erf_imshow = partial(
_erfimage_imshow,
scalings=scale_coeffs,
order=order,
data=data,
epochs=epochs,
sigma=sigma,
cmap=cmap,
vlim_array=vlim_array,
colorbar=True,
)
fig = _plot_topo(
info=epochs.info,
times=epochs.times,
click_func=erf_imshow,
show_func=show_func,
layout=layout,
colorbar=colorbar,
vmin=vmin,
vmax=vmax,
cmap=cmap,
layout_scale=layout_scale,
title=title,
fig_facecolor=fig_facecolor,
font_color=font_color,
border=border,
x_label="Time (s)",
y_label="Epoch",
unified=True,
img=True,
select=select,
)
add_background_image(fig, fig_background)
plt_show(show)
return fig
def _pos_to_bbox(pos, ax):
"""Convert layout position to bbox."""
import matplotlib.transforms as mtransforms
return mtransforms.TransformedBbox(
mtransforms.Bbox.from_bounds(*pos),
ax.transAxes,
)