{
"cells": [
{
"cell_type": "markdown",
"id": "iraqi-brook",
"metadata": {},
"source": [
"# This notebook provides template code to prepare an EEG signal dataset for EEG-GCNN model training or evaluation. The same code can be adapted to process other EEG datasets for training other models."
]
},
{
"cell_type": "markdown",
"id": "minor-verse",
"metadata": {},
"source": [
"# Here, the TUH Epilepsy Corpus is used as an example."
]
},
{
"cell_type": "markdown",
"id": "underlying-brass",
"metadata": {},
"source": [
"Dataset documentation - https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_epilepsy/v1.0.0/_AAREADME.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "executed-packet",
"metadata": {},
"outputs": [],
"source": [
"from glob2 import glob\n",
"\n",
"edf_file_list = glob(\"../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/*/*/*.edf\")\n",
"len(edf_file_list)"
]
},
{
"cell_type": "markdown",
"id": "latin-hearing",
"metadata": {},
"source": [
"There are 1648 EDF files in the corpus, as the README states. However, many subjects have multiple files, so we keep only the unique subject IDs and drop duplicate entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rational-parks",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"# extract subject IDs from the file path, create python set to extract unique elements from list, convert to list again \n",
"unique_epilepsy_patient_ids = list(set([x.split(\"/\")[-1].split(\"_\")[0] for x in edf_file_list]))\n",
"len(unique_epilepsy_patient_ids)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "neural-average",
"metadata": {},
"outputs": [],
"source": [
"with open('../data/subject_lists/epilepsy_corpus_subjects.txt', 'w') as file_handler:\n",
"    for item in unique_epilepsy_patient_ids:\n",
"        file_handler.write(\"{}\\n\".format(item))"
]
},
{
"cell_type": "markdown",
"id": "julian-coordinate",
"metadata": {},
"source": [
"This list of subject IDs is used to create an index csv file that contains all the dataset metadata you'd need to train and evaluate any predictive model on the Epilepsy corpus."
]
},
{
"cell_type": "markdown",
"id": "ranking-torture",
"metadata": {},
"source": [
"### See which sensor configurations are available for the Epilepsy corpus"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "satisfied-catch",
"metadata": {},
"outputs": [],
"source": [
"edf_file_list = glob(\"../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/*/*/*.edf\")\n",
"channel_configs = [x.split(\"/\")[5] for x in edf_file_list]\n",
"list(set(channel_configs))"
]
},
{
"cell_type": "markdown",
"id": "infrared-niger",
"metadata": {},
"source": [
"The corpus contains 3 different channel configurations. Ensure that all the channels you need exist in each of them."
]
},
{
"cell_type": "markdown",
"id": "subtle-citizenship",
"metadata": {},
"source": [
"### Open a signal file for each configuration and check channels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pointed-artwork",
"metadata": {},
"outputs": [],
"source": [
"import mne\n",
"# file_path = r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/02_tcp_le/055/00005573/s001_2009_01_20/00005573_s001_t000.edf'\n",
"# file_path = r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/01_tcp_ar/098/00009853/s001_2013_04_10/00009853_s001_t000.edf'\n",
"file_path = r'../data/tuh_eeg_epilepsy_corpus/edf/no_epilepsy/03_tcp_ar_a/076/00007671/s002_2011_02_03/00007671_s002_t001.edf'\n",
"\n",
"raw_data = mne.io.read_raw_edf(file_path, verbose=False, preload=False)\n",
"raw_data.info[\"ch_names\"]"
]
},
{
"cell_type": "markdown",
"id": "foster-wyoming",
"metadata": {},
"source": [
"- eeg-gcnn channels not in 01_tcp_ar - T7, T8, P7, P8\n",
"- eeg-gcnn channels not in 02_tcp_le - T7, T8, P7, P8\n",
"- eeg-gcnn channels not in 03_tcp_ar_a - T7, T8, P7, P8"
]
},
{
"cell_type": "markdown",
"id": "widespread-desire",
"metadata": {},
"source": [
"## Create dataset index with 10 second non-overlapping consecutive windows and labels for the task"
]
},
{
"cell_type": "markdown",
"id": "ceramic-springfield",
"metadata": {},
"source": [
"Each recording is broken down into 10s windows. Each window gets one row of metadata in the index csv."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "retired-sussex",
"metadata": {},
"outputs": [],
"source": [
"from glob2 import glob\n",
"import mne\n",
"import pandas as pd\n",
"\n",
"# read back the unique subject IDs saved earlier; the file is closed automatically\n",
"with open('../data/subject_lists/epilepsy_corpus_subjects.txt', 'r') as f:\n",
"    unique_epilepsy_patient_ids = [x.strip() for x in f.readlines()]\n",
"\n",
"# pick your desired preprocessing configuration.\n",
"SAMPLING_FREQ = 250.0\n",
"WINDOW_LENGTH_SECONDS = 10.0\n",
"WINDOW_LENGTH_SAMPLES = int(WINDOW_LENGTH_SECONDS * SAMPLING_FREQ)\n",
"\n",
"\n",
"# loop over one subject at a time, and add corresponding metadata to csv\n",
"dataset_index_rows = []\n",
"label_count = {\n",
"    \"epilepsy\": 0,\n",
"    \"no_epilepsy\": 0\n",
"}\n",
"for idx, patient_id in enumerate(unique_epilepsy_patient_ids):\n",
"\n",
"    print(f\"\\n\\n\\n {patient_id} : {idx+1}/{len(unique_epilepsy_patient_ids)} \\n\\n\\n\")\n",
"\n",
"    # find all edf files corresponding to this patient id\n",
"    patient_edf_file_list = glob(f\"../data/tuh_eeg_epilepsy_corpus/edf/*epilepsy/*/*/{patient_id}/*/{patient_id}_*.edf\")\n",
"    assert len(patient_edf_file_list) >= 1\n",
"\n",
"    # CAUTION - later ignoring multiple recordings of a subject, taking only one.\n",
"    print(len(patient_edf_file_list))\n",
"\n",
"    # get label of the recording from the file path, ensure all labels for the same subject are the same\n",
"    # NOTE - the label of the recording is copied to each of its windows\n",
"    labels = [x.split(\"/\")[4] for x in patient_edf_file_list]\n",
"    assert labels == [labels[0]]*len(labels)\n",
"    print(labels)\n",
"\n",
"    label = labels[0]\n",
"    label_count[label] += 1\n",
"\n",
"    # CAUTION - considering only the first recording here!\n",
"    raw_file_path = patient_edf_file_list[0]\n",
"    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=False, preload=False)\n",
"\n",
"    # generate window metadata = one row of dataset_index\n",
"    for start_sample_index in range(0, int(int(raw_data.times[-1]) * SAMPLING_FREQ), WINDOW_LENGTH_SAMPLES):\n",
"\n",
"        end_sample_index = start_sample_index + (WINDOW_LENGTH_SAMPLES - 1)\n",
"\n",
"        # ensure 10 seconds are available in window and recording does not end\n",
"        # NOTE - raw_data.n_times counts samples at the recording's original sampling rate,\n",
"        # so this check is exact only when that rate equals SAMPLING_FREQ.\n",
"        if end_sample_index > raw_data.n_times:\n",
"            break\n",
"\n",
"        row = {}\n",
"        row[\"patient_id\"] = patient_id\n",
"        row[\"raw_file_path\"] = raw_file_path\n",
"        row[\"record_length_seconds\"] = raw_data.times[-1]\n",
"        # this is the desired SFREQ using which sample indices are derived.\n",
"        # CAUTION - this is not the original SFREQ at which the data is recorded.\n",
"        row[\"sampling_freq\"] = SAMPLING_FREQ\n",
"        row[\"channel_config\"] = raw_file_path.split(\"/\")[5]\n",
"        row[\"start_sample_index\"] = start_sample_index\n",
"        row[\"end_sample_index\"] = end_sample_index\n",
"        row[\"text_label\"] = label\n",
"        row[\"numeric_label\"] = 0 if label == \"no_epilepsy\" else 1\n",
"        dataset_index_rows.append(row)\n",
"\n",
"# create dataframe from rows, save to disk\n",
"df = pd.DataFrame(dataset_index_rows, columns=[\"patient_id\",\n",
"                                               \"raw_file_path\",\n",
"                                               \"record_length_seconds\",\n",
"                                               \"sampling_freq\",\n",
"                                               \"channel_config\",\n",
"                                               \"start_sample_index\",\n",
"                                               \"end_sample_index\",\n",
"                                               \"text_label\",\n",
"                                               \"numeric_label\"])\n",
"df.to_csv(\"epilepsy_corpus_window_index.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "christian-survey",
"metadata": {},
"outputs": [],
"source": [
"label_count"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "amateur-flower",
"metadata": {},
"outputs": [],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "neither-congo",
"metadata": {},
"outputs": [],
"source": [
"df.shape"
]
},
{
"cell_type": "markdown",
"id": "impossible-feeding",
"metadata": {},
"source": [
"This dataset yielded a total of 33864 windows (each 10 s long), which can now be used for feature extraction and model training/evaluation."
]
},
{
"cell_type": "markdown",
"id": "cardiac-shooting",
"metadata": {},
"source": [
"# 1) Iterate over the dataset using index, 2) preprocess each recording, 3) generate brain rhythm PSD + connectivity features, 4) save computed features to disk as numpy array"
]
},
{
"cell_type": "markdown",
"id": "unexpected-framing",
"metadata": {},
"source": [
"This step takes time when done serially. Since the signal in each window is assumed to be independent, you can parallelize the feature generation step if you'd like. Here, we do it serially."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "appointed-sweet",
"metadata": {},
"outputs": [],
"source": [
"import mne\n",
"import numpy as np\n",
"import pandas as pd\n",
"# NOTE - newer MNE releases moved this function to the separate mne-connectivity package\n",
"# (as spectral_connectivity_epochs); this import needs an MNE version that still ships mne.connectivity.\n",
"from mne.connectivity import spectral_connectivity\n",
"from mne.time_frequency import psd_array_welch\n",
"\n",
"# most of these are simply wrappers around mne-python\n",
"from eeg_pipeline import standardize_sensors, downsample, highpass, remove_line_noise, get_brain_waves_power\n",
"\n",
"# for purposes of feature generation, you don't want to open up the same recording again and again for different windows within it\n",
"# therefore, group the index by recording file, and iterate over the grouped dataframe instead.\n",
"index_df = pd.read_csv(\"epilepsy_corpus_window_index.csv\")\n",
"grouped_df = index_df.groupby(\"raw_file_path\")\n",
"\n",
"# create empty feature matrices to fill in\n",
"# PSD features - 8 channels x 6 brain rhythms power\n",
"# Connectivity features - 64 edges (8 x 8 channel pairs, self-pairs included) in a fully-connected graph with 8 channels\n",
"feature_matrix = np.zeros((index_df.shape[0], 8*6))\n",
"spec_coh_matrix = np.zeros((index_df.shape[0], 64))\n",
"\n",
"SAMPLING_FREQ = 250.0\n",
"\n",
"# open up one raw_file at a time.\n",
"for raw_file_path, group_df in grouped_df:\n",
"\n",
"    print(f\"FILE NAME: {raw_file_path}\")\n",
"    print(f\"WINDOW IDS IN FILE: {group_df.index.tolist()}\")\n",
"    channel_config = str(group_df[\"channel_config\"].unique()[0])\n",
"    print(channel_config)\n",
"\n",
"    # NOTE - PREPROCESSING = open the file, select channels, apply montage, downsample to 250, highpass, notch filter\n",
"    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=True, preload=True)\n",
"    raw_data = standardize_sensors(raw_data, channel_config, return_montage=True)\n",
"    raw_data, sfreq = downsample(raw_data, SAMPLING_FREQ)\n",
"    raw_data = highpass(raw_data, 1.0)\n",
"    raw_data = remove_line_noise(raw_data)\n",
"\n",
"    # data is ready for feature extraction, loop over windows, extract features\n",
"    for window_idx in group_df.index.tolist():\n",
"\n",
"        # get raw data for the window\n",
"        # NOTE - `stop` is exclusive in get_data, so the window spans samples start_sample .. stop_sample - 1.\n",
"        start_sample = group_df.loc[window_idx]['start_sample_index']\n",
"        stop_sample = group_df.loc[window_idx]['end_sample_index']\n",
"        window_data = raw_data.get_data(start=start_sample, stop=stop_sample)\n",
"\n",
"        # CONNECTIVITY EDGE FEATURES - compute spectral coherence values between all sensors within the window\n",
"        # required transformation for mne spectral connectivity API (adds a leading epochs axis)\n",
"        transf_window_data = np.expand_dims(window_data, axis=0)\n",
"\n",
"        # the spectral connectivity of each channel with every other.\n",
"        for ch_idx in range(8):\n",
"\n",
"            # https://mne.tools/stable/generated/mne.connectivity.spectral_connectivity.html#mne.connectivity.spectral_connectivity\n",
"            spec_conn, freqs, times, n_epochs, n_tapers = spectral_connectivity(\n",
"                data=transf_window_data,\n",
"                method='coh',\n",
"                indices=([ch_idx]*8, range(8)),\n",
"                sfreq=SAMPLING_FREQ,\n",
"                # fmin=(1.0, 4.0, 7.5, 13.0, 16.0, 30.0),\n",
"                # fmax=(4.0, 7.5, 13.0, 16.0, 30.0, 40.0),\n",
"                fmin=1.0, fmax=40.0,\n",
"                faverage=True, verbose=False)\n",
"\n",
"            # print(np.squeeze(spec_conn))\n",
"            # print(freqs)\n",
"            # print(times)\n",
"            # print(n_epochs)\n",
"            # print(n_tapers)\n",
"\n",
"            spec_coh_values = np.squeeze(spec_conn)\n",
"            assert spec_coh_values.shape[0] == 8\n",
"\n",
"            # save to connectivity feature matrix at appropriate index\n",
"            start_edge_idx = ch_idx * 8\n",
"            end_edge_idx = start_edge_idx + 8\n",
"            spec_coh_matrix[window_idx, start_edge_idx:end_edge_idx] = spec_coh_values\n",
"\n",
"        # print(\"[WINDOW] CONNECTIVITY FEATURES DONE!...\")\n",
"\n",
"        # PSD NODE FEATURES - derive total power in 6 brain rhythm bands for each montage channel\n",
"        psd_welch, freqs = psd_array_welch(window_data, sfreq=SAMPLING_FREQ, fmax=50.0, n_per_seg=150,\n",
"                                           average='mean', verbose=False)\n",
"        # Convert power to dB scale.\n",
"        psd_welch = 10 * np.log10(psd_welch)\n",
"        band_powers = get_brain_waves_power(psd_welch, freqs)\n",
"        assert band_powers.shape == (8, 6)\n",
"\n",
"        # flatten all features, and save to feature matrix at appropriate index\n",
"        features = band_powers.flatten()\n",
"        feature_matrix[window_idx, :] = features\n",
"\n",
"        # print(\"[WINDOW] PSD FEATURES DONE!...\")\n",
"\n",
"    print(\"\\n[RECORDING] ALL WINDOWS DONE! FILE DONE!...\\n\")\n",
"\n",
"# save the features and labels as numpy array to disk\n",
"np.save(\"../data/saved_numpy_arrays/X_psd_epilepsy_corpus.npy\", feature_matrix)\n",
"np.save(\"../data/saved_numpy_arrays/X_spec_coh_epilepsy_corpus.npy\", spec_coh_matrix)\n",
"np.save(\"../data/saved_numpy_arrays/y_epilepsy_corpus.npy\", index_df[\"text_label\"].to_numpy())\n",
"\n",
"print(\"\\nALL ARRAYS SAVED TO DISK!...\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}