# This notebook provides template code to prepare your favorite EEG signal datasets for either EEG-GCNN model training or evaluation. In fact, you can use this code to process any EEG dataset for training any model you like.

# Here, the TUH Epilepsy Corpus is used as an example.

Dataset documentation - https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_epilepsy/v1.0.0/_AAREADME.txt

In [None]:
from glob2 import glob

# Collect every EDF recording in the corpus. Each "*" stands for one directory
# level between the class folder (matched by "*epilepsy") and the .edf file.
EDF_GLOB_PATTERN = "../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/*/*/*.edf"
edf_file_list = glob(EDF_GLOB_PATTERN)

# number of recordings found
len(edf_file_list)

There are 1648 EDF files in the corpus, just like the README mentions. However, many subjects have multiple files, so we keep only the unique subject IDs and drop the duplicate entries.

In [None]:
import numpy as np
# Extract the subject ID from each file path: the file name (last "/"-separated
# component, e.g. "00005573_s001_t000.edf") starts with the ID, followed by "_".
# A set removes the duplicates, and sorting makes the order identical on every
# run -- iteration order of a set of strings can change between interpreter
# sessions, which would otherwise reorder the subject list and everything
# derived from it.
unique_epilepsy_patient_ids = sorted({x.split("/")[-1].split("_")[0] for x in edf_file_list})

# number of distinct subjects
len(unique_epilepsy_patient_ids)

In [None]:
# Persist the unique subject IDs to a text file, one ID per line.
subject_lines = ["{}\n".format(patient_id) for patient_id in unique_epilepsy_patient_ids]
with open('../data/subject_lists/epilepsy_corpus_subjects.txt', 'w') as subjects_file:
    subjects_file.writelines(subject_lines)

This list of subject IDs is used to create an index csv file that contains all the dataset metadata you'd need to train and evaluate any predictive model on the Epilepsy corpus.

### See which sensor configurations are available for the Epilepsy corpus

In [None]:
# Re-list all EDF recordings so this cell does not depend on earlier state.
edf_file_list = glob("../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/*/*/*.edf")

# The channel configuration is the folder right below the class folder, i.e. the
# "/"-separated path component at index 5 for paths matching the pattern above.
channel_configs = list(map(lambda edf_path: edf_path.split("/")[5], edf_file_list))

# distinct configurations present in the corpus
list(set(channel_configs))

The corpus contains 3 different channel configurations. Ensure that all the channels you need exist in each of them.

### Open a signal file for each configuration and check channels

In [None]:
import mne
# One example recording per channel configuration. Change SELECTED_CONFIG to
# inspect the channel names of a different configuration.
EXAMPLE_FILE_BY_CONFIG = {
    "02_tcp_le": r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/02_tcp_le/055/00005573/s001_2009_01_20/00005573_s001_t000.edf',
    "01_tcp_ar": r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/01_tcp_ar/098/00009853/s001_2013_04_10/00009853_s001_t000.edf',
    "03_tcp_ar_a": r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/03_tcp_ar_a/076/00007671/s002_2011_02_03/00007671_s002_t001.edf',
}
SELECTED_CONFIG = "03_tcp_ar_a"
file_path = EXAMPLE_FILE_BY_CONFIG[SELECTED_CONFIG]

# Open the recording (preload=False) and show the channel names it contains.
raw_data = mne.io.read_raw_edf(file_path, preload=False, verbose=False)
raw_data.info["ch_names"]

- eeg-gcnn channels not in 01_tcp_ar - T7, T8, P7, P8
- eeg-gcnn channels not in 02_tcp_le - T7, T8, P7, P8
- eeg-gcnn channels not in 03_tcp_ar_a - T7, T8, P7, P8

## Create dataset index with 10 second non-overlapping consecutive windows and labels for the task

Each recording is broken down into 10s windows. Each window gets one row of metadata in the index csv.

In [None]:
from glob2 import glob
import mne
import pandas as pd

# Load the unique subject IDs saved earlier (one per line). The context manager
# closes the file handle, and blank lines are skipped so they can never turn
# into an empty subject ID.
with open('../data/subject_lists/epilepsy_corpus_subjects.txt', 'r') as f:
    unique_epilepsy_patient_ids = [x.strip() for x in f if x.strip()]

# pick your desired preprocessing configuration.
SAMPLING_FREQ = 250.0
WINDOW_LENGTH_SECONDS = 10.0
WINDOW_LENGTH_SAMPLES = int(WINDOW_LENGTH_SECONDS * SAMPLING_FREQ)


# loop over one subject at a time, and add corresponding metadata to csv
dataset_index_rows = []
label_count = {
    "epilepsy": 0,
    "no_epilepsy": 0
}
for idx, patient_id in enumerate(unique_epilepsy_patient_ids):

    print(f"\n\n\n {patient_id} : {idx+1}/{len(unique_epilepsy_patient_ids)} \n\n\n")

    # find all edf files corresponding to this patient id; sorted so that the
    # "first" recording picked below is the same on every machine and run
    # (glob returns files in filesystem order, which is not guaranteed).
    patient_edf_file_list = sorted(
        glob(f"../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/{patient_id}/*/{patient_id}_*.edf")
    )
    assert len(patient_edf_file_list) >= 1, f"no EDF files found for subject {patient_id}"

    # CAUTION - later ignoring multiple recordings of a subject, taking only one.
    print(len(patient_edf_file_list))

    # get label of the recording from the path (component at index 4 is the class
    # folder), ensure all labels for the same subject are the same
    # NOTE - the label of the recording is copied to each of its windows
    labels = [x.split("/")[4] for x in patient_edf_file_list]
    assert len(set(labels)) == 1, f"inconsistent labels for subject {patient_id}: {labels}"
    print (labels)

    label = labels[0]
    label_count[label] += 1

    # CAUTION - considering only the first recording here!
    raw_file_path = patient_edf_file_list[0]
    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=False, preload=False)

    # Values shared by every window of this recording, computed once.
    record_length_seconds = raw_data.times[-1]
    channel_config = raw_file_path.split("/")[5]
    numeric_label = 0 if label == "no_epilepsy" else 1

    # Window indices are expressed at SAMPLING_FREQ, NOT at the rate the file was
    # recorded at, so the recording length must be converted to that rate before
    # it can be compared against them. raw_data.n_times alone counts samples at
    # the original rate and over-estimates the length for files recorded faster
    # than SAMPLING_FREQ.
    # assumes the later resampling yields round(n_times * ratio) samples -- TODO confirm
    original_sfreq = raw_data.info["sfreq"]
    n_samples_at_target_freq = int(round(raw_data.n_times * SAMPLING_FREQ / original_sfreq))

    # generate window metadata = one row of dataset_index
    for start_sample_index in range(0, int(int(record_length_seconds) * SAMPLING_FREQ), WINDOW_LENGTH_SAMPLES):

        # inclusive index of the last sample in the window
        end_sample_index = start_sample_index + (WINDOW_LENGTH_SAMPLES - 1)

        # ensure 10 seconds are available in window and recording does not end:
        # the inclusive end index has to be a valid index into the resampled signal
        if end_sample_index >= n_samples_at_target_freq:
            break

        dataset_index_rows.append({
            "patient_id": patient_id,
            "raw_file_path": raw_file_path,
            "record_length_seconds": record_length_seconds,
            # this is the desired SFREQ using which sample indices are derived.
            # CAUTION - this is not the original SFREQ at which the data is recorded.
            "sampling_freq": SAMPLING_FREQ,
            "channel_config": channel_config,
            "start_sample_index": start_sample_index,
            "end_sample_index": end_sample_index,
            "text_label": label,
            "numeric_label": numeric_label,
        })

# create dataframe from rows (explicit column order), save to disk
df = pd.DataFrame(dataset_index_rows, columns=["patient_id",
                                               "raw_file_path",
                                               "record_length_seconds",
                                               "sampling_freq",
                                               "channel_config",
                                               "start_sample_index",
                                               "end_sample_index",
                                               "text_label",
                                               "numeric_label"])
df.to_csv("epilepsy_corpus_window_index.csv", index=False)

In [None]:
# number of subjects counted per class label while building the index
label_count

In [None]:
# display the window index dataframe (one row per 10s window)
df

In [None]:
# (number of windows, number of metadata columns)
df.shape

This dataset yielded a total of 33864 windows (each 10s long), which can now be used for feature extraction and model training/evaluation.

# 1) Iterate over the dataset using index, 2) preprocess each recording, 3) generate brain rhythm PSD + connectivity features, 4) save computed features to disk as numpy array

This step takes time when done serially. Since the signal in each window is assumed to be independent, you can parallelize the feature generation step if you'd like. Here, we do it serially.

In [None]:
import os

import mne
import numpy as np
import pandas as pd
# Imported once here rather than inside the per-window loop.
# NOTE(review): newer MNE releases appear to have moved spectral connectivity into
# the separate mne-connectivity package -- confirm the installed MNE still
# provides mne.connectivity.
from mne.connectivity import spectral_connectivity
from mne.time_frequency import psd_array_welch

# most of these are simply wrappers around mne-python
from eeg_pipeline import standardize_sensors, downsample, highpass, remove_line_noise, get_brain_waves_power

# Make sure the output directory exists BEFORE the long computation starts, so
# the run cannot fail at the very end when the arrays are saved.
os.makedirs("../data/saved_numpy_arrays", exist_ok=True)

# for purposes of feature generation, you don't want to open up the same recording again and again for different windows within it
# therefore, group index by recording file, iterate over grouped dataframe instead.
index_df = pd.read_csv("epilepsy_corpus_window_index.csv")
grouped_df = index_df.groupby("raw_file_path")

N_CHANNELS = 8        # montage channels kept after sensor standardization
N_BRAIN_RHYTHMS = 6   # brain rhythm bands returned by get_brain_waves_power
SAMPLING_FREQ = 250.0

# create empty feature matrices to fill in; row i belongs to window i of index_df
# PSD features - 8 channels x 6 brain rhythms power
# Connectivity features - 64 directed edges in a fully-connected graph with 8 channels
feature_matrix = np.zeros((index_df.shape[0], N_CHANNELS * N_BRAIN_RHYTHMS))
spec_coh_matrix = np.zeros((index_df.shape[0], N_CHANNELS * N_CHANNELS))

# open up one raw_file at a time.
for raw_file_path, group_df in grouped_df:

    print(f"FILE NAME: {raw_file_path}")
    print(f"WINDOW IDS IN FILE: {group_df.index.tolist()}")
    channel_config = str(group_df["channel_config"].unique()[0])
    print(channel_config)

    # NOTE - PREPROCESSING = open the file, select channels, apply montage, downsample to 250, highpass, notch filter
    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=True, preload=True)
    raw_data = standardize_sensors(raw_data, channel_config, return_montage=True)
    raw_data, sfreq = downsample(raw_data, SAMPLING_FREQ)
    raw_data = highpass(raw_data, 1.0)
    raw_data = remove_line_noise(raw_data)

    # data is ready for feature extraction, loop over windows, extract features
    for window_idx in group_df.index.tolist():

        # get raw data for the window.
        # The index stores an INCLUSIVE end sample, while get_data() treats `stop`
        # as exclusive, so one is added to keep the last sample of the window.
        start_sample = int(group_df.at[window_idx, 'start_sample_index'])
        stop_sample = int(group_df.at[window_idx, 'end_sample_index']) + 1
        window_data = raw_data.get_data(start=start_sample, stop=stop_sample)

        # CONNECTIVITY EDGE FEATURES - compute spectral coherence values between all sensors within the window
        # required transformation for mne spectral connectivity API (adds a leading epoch axis)
        transf_window_data = np.expand_dims(window_data, axis=0)

        # the spectral connectivity of each channel with every other.
        for ch_idx in range(N_CHANNELS):

            # https://mne.tools/stable/generated/mne.connectivity.spectral_connectivity.html#mne.connectivity.spectral_connectivity
            # Coherence is averaged over a single 1-40 Hz band; per-band values
            # could be obtained by passing tuples of band edges as fmin/fmax instead.
            # Only the connectivity values (first returned item) are needed.
            spec_conn = spectral_connectivity(data=transf_window_data,
                                              method='coh',
                                              indices=([ch_idx] * N_CHANNELS, range(N_CHANNELS)),
                                              sfreq=SAMPLING_FREQ,
                                              fmin=1.0, fmax=40.0,
                                              faverage=True, verbose=False)[0]

            spec_coh_values = np.squeeze(spec_conn)
            assert spec_coh_values.shape[0] == N_CHANNELS

            # save to connectivity feature matrix at appropriate index:
            # edges starting at channel ch_idx occupy one contiguous slice
            start_edge_idx = ch_idx * N_CHANNELS
            end_edge_idx = start_edge_idx + N_CHANNELS
            spec_coh_matrix[window_idx, start_edge_idx:end_edge_idx] = spec_coh_values

        # PSD NODE FEATURES - derive total power in 6 brain rhythm bands for each montage channel
        psd_welch, freqs = psd_array_welch(window_data, sfreq=SAMPLING_FREQ, fmax=50.0, n_per_seg=150,
                                           average='mean', verbose=False)
        # Convert power to dB scale.
        psd_welch = 10 * np.log10(psd_welch)
        band_powers = get_brain_waves_power(psd_welch, freqs)
        assert band_powers.shape == (N_CHANNELS, N_BRAIN_RHYTHMS)

        # flatten all features, and save to feature matrix at appropriate index
        features = band_powers.flatten()
        feature_matrix[window_idx, :] = features

    print ("\n[RECORDING] ALL WINDOWS DONE! FILE DONE!...\n")

# save the features and labels as numpy array to disk
np.save("../data/saved_numpy_arrays/X_psd_epilepsy_corpus.npy", feature_matrix)
np.save("../data/saved_numpy_arrays/X_spec_coh_epilepsy_corpus.npy", spec_coh_matrix)
np.save("../data/saved_numpy_arrays/y_epilepsy_corpus.npy", index_df["text_label"].to_numpy())

print ("\nALL ARRAYS SAVED TO DISK!...\n")