|
a |
|
b/eeg_pipeline.py |
|
|
1 |
import mne |
|
|
2 |
import matplotlib.pyplot as plt |
|
|
3 |
import pandas as pd |
|
|
4 |
import numpy as np |
|
|
5 |
import scipy.io |
|
|
6 |
import os |
|
|
7 |
from collections import OrderedDict |
|
|
8 |
|
|
|
9 |
def standardize_sensors(raw_data, channel_config, return_montage=True): |
|
|
10 |
|
|
|
11 |
# channel_names = [x.upper() for x in raw_data.ch_names] |
|
|
12 |
|
|
|
13 |
NUM_REDUCED_SENSORS = 8 |
|
|
14 |
montage_sensor_set = ["F7", "F3", "F8", "F4", "T3", "C3", "T4", "C4", "T5", "P3", "T6", "P4", "O1", "O2"] |
|
|
15 |
first = ["F7", "F8", "T3", "T4", "T5", "T6", "O1", "O2"] |
|
|
16 |
second = ["F3", "F4", "C3", "C4", "P3", "P4", "P3", "P4"] |
|
|
17 |
|
|
|
18 |
if channel_config in ["01_tcp_ar", "03_tcp_ar_a"]: |
|
|
19 |
montage_sensor_set = [str("EEG "+x+"-REF") for x in montage_sensor_set] |
|
|
20 |
first = [str("EEG "+x+"-REF") for x in first] |
|
|
21 |
second = [str("EEG "+x+"-REF") for x in second] |
|
|
22 |
|
|
|
23 |
elif channel_config == "02_tcp_le": |
|
|
24 |
montage_sensor_set = [str("EEG "+x+"-LE") for x in montage_sensor_set] |
|
|
25 |
first = [str("EEG "+x+"-LE") for x in first] |
|
|
26 |
second = [str("EEG "+x+"-LE") for x in second] |
|
|
27 |
|
|
|
28 |
raw_data = raw_data.pick_channels(montage_sensor_set, ordered=True) |
|
|
29 |
|
|
|
30 |
# return channels without subtraction - 14 of them |
|
|
31 |
if return_montage == False: |
|
|
32 |
return raw_data, raw_data |
|
|
33 |
|
|
|
34 |
# use a sensor's data to get total number of samples |
|
|
35 |
reduced_data = np.zeros((NUM_REDUCED_SENSORS, raw_data.n_times)) |
|
|
36 |
|
|
|
37 |
# create derived channels |
|
|
38 |
for idx in range(NUM_REDUCED_SENSORS): |
|
|
39 |
reduced_data[idx, :] = raw_data[first[idx]][0][:] - raw_data[second[idx]][0][:] |
|
|
40 |
|
|
|
41 |
# create new info object for reduced sensors |
|
|
42 |
reduced_info = mne.create_info(ch_names=[ |
|
|
43 |
"F7-F3", "F8-F4", |
|
|
44 |
"T3-C3", "T4-C4", |
|
|
45 |
"T5-P3", "T6-P4", |
|
|
46 |
"O1-P3", "O2-P4" |
|
|
47 |
], sfreq=raw_data.info["sfreq"], ch_types=["eeg"]*NUM_REDUCED_SENSORS) |
|
|
48 |
|
|
|
49 |
# https://mne.tools/dev/auto_examples/io/plot_objects_from_arrays.html?highlight=rawarray |
|
|
50 |
reduced_raw_data = mne.io.RawArray(reduced_data, reduced_info) |
|
|
51 |
# return reduced_raw_data, raw_data |
|
|
52 |
return reduced_raw_data |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
def downsample(raw_data, freq=250): |
|
|
56 |
raw_data = raw_data.resample(sfreq=freq) |
|
|
57 |
return raw_data, freq |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def highpass(raw_data, cutoff=1.0): |
|
|
61 |
raw_data.filter(l_freq=cutoff, h_freq=None) |
|
|
62 |
return raw_data |
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def remove_line_noise(raw_data, ac_freqs = np.arange(50, 101, 50)): |
|
|
66 |
raw_data.notch_filter(freqs=ac_freqs, picks="eeg", verbose=False) |
|
|
67 |
return raw_data |
|
|
68 |
|
|
|
69 |
# accepts PSD of all sensors, returns band power for all sensors |
|
|
70 |
def get_brain_waves_power(psd_welch, freqs): |
|
|
71 |
|
|
|
72 |
brain_waves = OrderedDict({ |
|
|
73 |
"delta" : [1.0, 4.0], |
|
|
74 |
"theta": [4.0, 7.5], |
|
|
75 |
"alpha": [7.5, 13.0], |
|
|
76 |
"lower_beta": [13.0, 16.0], |
|
|
77 |
"higher_beta": [16.0, 30.0], |
|
|
78 |
"gamma": [30.0, 40.0] |
|
|
79 |
}) |
|
|
80 |
|
|
|
81 |
# create new variable you want to "fill": n_brain_wave_bands |
|
|
82 |
band_powers = np.zeros((psd_welch.shape[0], 6)) |
|
|
83 |
|
|
|
84 |
for wave_idx, wave in enumerate(brain_waves.keys()): |
|
|
85 |
# identify freq indices of the wave band |
|
|
86 |
if wave_idx == 0: |
|
|
87 |
band_freqs_idx = np.argwhere((freqs <= brain_waves[wave][1])) |
|
|
88 |
else: |
|
|
89 |
band_freqs_idx = np.argwhere((freqs >= brain_waves[wave][0]) & (freqs <= brain_waves[wave][1])) |
|
|
90 |
|
|
|
91 |
# extract the psd values for those freq indices |
|
|
92 |
band_psd = psd_welch[:, band_freqs_idx.ravel()] |
|
|
93 |
|
|
|
94 |
# sum the band psd data to get total band power |
|
|
95 |
total_band_power = np.sum(band_psd, axis=1) |
|
|
96 |
|
|
|
97 |
# set power in band for all sensors |
|
|
98 |
band_powers[:, wave_idx] = total_band_power |
|
|
99 |
|
|
|
100 |
return band_powers |