Diff of /eeg_pipeline.py [000000] .. [4dadda]

Switch to unified view

a b/eeg_pipeline.py
1
import mne
2
import matplotlib.pyplot as plt
3
import pandas as pd
4
import numpy as np
5
import scipy.io
6
import os
7
from collections import OrderedDict
8
9
def standardize_sensors(raw_data, channel_config, return_montage=True):
10
11
    # channel_names = [x.upper() for x in raw_data.ch_names]
12
    
13
    NUM_REDUCED_SENSORS = 8
14
    montage_sensor_set = ["F7", "F3", "F8", "F4", "T3", "C3", "T4", "C4", "T5", "P3", "T6", "P4", "O1", "O2"]
15
    first = ["F7", "F8", "T3", "T4", "T5", "T6", "O1", "O2"]
16
    second = ["F3", "F4", "C3", "C4", "P3", "P4", "P3", "P4"]
17
    
18
    if channel_config in ["01_tcp_ar", "03_tcp_ar_a"]:
19
        montage_sensor_set = [str("EEG "+x+"-REF") for x in montage_sensor_set]
20
        first = [str("EEG "+x+"-REF") for x in first]
21
        second = [str("EEG "+x+"-REF") for x in second]
22
23
    elif channel_config == "02_tcp_le":
24
        montage_sensor_set = [str("EEG "+x+"-LE") for x in montage_sensor_set]
25
        first = [str("EEG "+x+"-LE") for x in first]
26
        second = [str("EEG "+x+"-LE") for x in second]
27
28
    raw_data = raw_data.pick_channels(montage_sensor_set, ordered=True)
29
30
    # return channels without subtraction - 14 of them
31
    if return_montage == False:
32
        return raw_data, raw_data
33
34
    # use a sensor's data to get total number of samples
35
    reduced_data = np.zeros((NUM_REDUCED_SENSORS, raw_data.n_times))
36
37
    # create derived channels
38
    for idx in range(NUM_REDUCED_SENSORS):
39
        reduced_data[idx, :] = raw_data[first[idx]][0][:] - raw_data[second[idx]][0][:]
40
    
41
    # create new info object for reduced sensors
42
    reduced_info = mne.create_info(ch_names=[
43
                                            "F7-F3", "F8-F4",
44
                                            "T3-C3", "T4-C4",
45
                                            "T5-P3", "T6-P4",
46
                                            "O1-P3", "O2-P4"
47
                                            ], sfreq=raw_data.info["sfreq"], ch_types=["eeg"]*NUM_REDUCED_SENSORS)
48
    
49
    # https://mne.tools/dev/auto_examples/io/plot_objects_from_arrays.html?highlight=rawarray
50
    reduced_raw_data = mne.io.RawArray(reduced_data, reduced_info)
51
    # return reduced_raw_data, raw_data
52
    return reduced_raw_data
53
54
55
def downsample(raw_data, freq=250):
56
    raw_data = raw_data.resample(sfreq=freq)
57
    return raw_data, freq
58
59
60
def highpass(raw_data, cutoff=1.0):
61
    raw_data.filter(l_freq=cutoff, h_freq=None)
62
    return raw_data
63
64
65
def remove_line_noise(raw_data, ac_freqs = np.arange(50, 101, 50)):
66
    raw_data.notch_filter(freqs=ac_freqs, picks="eeg", verbose=False)
67
    return raw_data
68
69
# accepts PSD of all sensors, returns band power for all sensors
70
def get_brain_waves_power(psd_welch, freqs):
71
72
    brain_waves = OrderedDict({
73
        "delta" : [1.0, 4.0],
74
        "theta": [4.0, 7.5],
75
        "alpha": [7.5, 13.0],
76
        "lower_beta": [13.0, 16.0],
77
        "higher_beta": [16.0, 30.0],
78
        "gamma": [30.0, 40.0]
79
    })
80
81
    # create new variable you want to "fill": n_brain_wave_bands
82
    band_powers = np.zeros((psd_welch.shape[0], 6))
83
84
    for wave_idx, wave in enumerate(brain_waves.keys()):
85
        # identify freq indices of the wave band
86
        if wave_idx == 0:
87
            band_freqs_idx = np.argwhere((freqs <= brain_waves[wave][1]))
88
        else:
89
            band_freqs_idx = np.argwhere((freqs >= brain_waves[wave][0]) & (freqs <= brain_waves[wave][1]))
90
91
        # extract the psd values for those freq indices
92
        band_psd = psd_welch[:, band_freqs_idx.ravel()]
93
94
        # sum the band psd data to get total band power
95
        total_band_power = np.sum(band_psd, axis=1)
96
97
        # set power in band for all sensors
98
        band_powers[:, wave_idx] = total_band_power    
99
100
    return band_powers