'''
Preprocess PhysioNet Sleep-EDF recordings into per-subject .npz files.

For each paired *PSG.edf / *Hypnogram.edf recording, extract one EEG
channel, split it into 30-second epochs, attach the sleep-stage label of
each epoch, trim unlabeled/unknown segments, and save x/y arrays plus the
EDF headers to <output_dir>/<subject>.npz.

https://github.com/akaraspt/deepsleepnet
Copyright 2017 Akara Supratak and Hao Dong. All rights reserved.
'''

import argparse
import glob
import math
import ntpath
import os
import shutil

from datetime import datetime

import numpy as np

from mne.io import read_raw_edf

import dhedfreader


# Sleep-stage label values (5-class scheme; EDF stages S3 and S4 are both
# mapped to N3 below).
W = 0
N1 = 1
N2 = 2
N3 = 3
REM = 4
UNKNOWN = 5

stage_dict = {
    "W": W,
    "N1": N1,
    "N2": N2,
    "N3": N3,
    "REM": REM,
    "UNKNOWN": UNKNOWN
}

class_dict = {
    0: "W",
    1: "N1",
    2: "N2",
    3: "N3",
    4: "REM",
    5: "UNKNOWN"
}

# Mapping from the hypnogram annotation strings to label values.
ann2label = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
    "Sleep stage ?": 5,
    "Movement time": 5
}

# Length of one labeled epoch, in seconds.
EPOCH_SEC_SIZE = 30


def main():
    """Convert all PSG/Hypnogram EDF pairs in --data_dir into .npz epoch
    files in --output_dir, using the channel named by --select_ch.

    Raises:
        ValueError: when a PSG file and its hypnogram disagree (different
            start times, mismatched counts) or when the labeled signal does
            not divide evenly into 30-s epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data_2013/",
                        help="Directory containing the *PSG.edf and "
                             "*Hypnogram.edf files.")
    parser.add_argument("--output_dir", type=str, default="data_2013/eeg_fpz_cz",
                        help="Directory where to save outputs.")
    parser.add_argument("--select_ch", type=str, default="EEG Fpz-Cz",
                        help="Name of the EEG channel to extract.")
    args = parser.parse_args()

    # Recreate the output directory from scratch so stale files from a
    # previous run cannot mix with the new output.
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)

    select_ch = args.select_ch

    # Pair PSG recordings with their hypnograms by sorted filename order.
    psg_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*PSG.edf")))
    ann_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf")))
    if len(psg_fnames) != len(ann_fnames):
        raise ValueError("Found {} PSG files but {} hypnogram files".format(
            len(psg_fnames), len(ann_fnames)))

    for psg_fname, ann_fname in zip(psg_fnames, ann_fnames):
        raw = read_raw_edf(psg_fname, preload=True, stim_channel=None)
        sampling_rate = raw.info['sfreq']
        # NOTE(review): `scaling_time` was removed in recent MNE releases;
        # kept as in the original — confirm against the installed version.
        raw_ch_df = raw.to_data_frame(scaling_time=100.0)[select_ch]
        raw_ch_df = raw_ch_df.to_frame()
        # BUG FIX: set_index returns a new frame; the original discarded
        # the result, leaving the index unchanged.
        raw_ch_df = raw_ch_df.set_index(np.arange(len(raw_ch_df)))

        # Read the raw EDF header to get the recording start time.
        # EDF files are binary: open in 'rb' (plain 'r' breaks on Python 3).
        with open(psg_fname, 'rb') as f:
            reader_raw = dhedfreader.BaseEDFReader(f)
            reader_raw.read_header()
            h_raw = reader_raw.header
        raw_start_dt = datetime.strptime(h_raw['date_time'],
                                         "%Y-%m-%d %H:%M:%S")

        # Read the annotation EDF header and its annotation records.
        with open(ann_fname, 'rb') as f:
            reader_ann = dhedfreader.BaseEDFReader(f)
            reader_ann.read_header()
            h_ann = reader_ann.header
            _, _, ann = zip(*reader_ann.records())
        ann_start_dt = datetime.strptime(h_ann['date_time'],
                                         "%Y-%m-%d %H:%M:%S")

        # Both files must describe the same recording session.
        if raw_start_dt != ann_start_dt:
            raise ValueError(
                "PSG and hypnogram start times differ: {} vs {}".format(
                    raw_start_dt, ann_start_dt))

        # Build per-epoch labels and the sample-index sets to keep/remove.
        remove_idx = []  # sample indices of UNKNOWN/movement segments
        labels = []      # per-annotation arrays of epoch labels
        label_idx = []   # sample indices covered by a known-stage annotation
        for a in ann[0]:
            onset_sec, duration_sec, ann_char = a
            ann_str = "".join(ann_char)
            label = ann2label[ann_str]
            # Sample indices spanned by this annotation.
            idx = int(onset_sec * sampling_rate) + np.arange(
                duration_sec * sampling_rate, dtype=int)
            if label != UNKNOWN:
                if duration_sec % EPOCH_SEC_SIZE != 0:
                    raise ValueError(
                        "Annotation duration {}s is not a multiple of "
                        "{}s".format(duration_sec, EPOCH_SEC_SIZE))
                duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
                # np.int was removed in NumPy 1.24 — use the builtin int.
                labels.append(np.ones(duration_epoch, dtype=int) * label)
                label_idx.append(idx)
                print("Include onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str))
            else:
                remove_idx.append(idx)
                print("Remove onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str))
        labels = np.hstack(labels)

        # Drop samples annotated as UNKNOWN/movement.
        print("before remove unwanted: {}".format(
            np.arange(len(raw_ch_df)).shape))
        if len(remove_idx) > 0:
            remove_idx = np.hstack(remove_idx)
            select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx)
        else:
            select_idx = np.arange(len(raw_ch_df))
        print("after remove unwanted: {}".format(select_idx.shape))

        # Keep only samples that carry a known-stage label.
        print("before intersect label: {}".format(select_idx.shape))
        label_idx = np.hstack(label_idx)
        select_idx = np.intersect1d(select_idx, label_idx)
        print("after intersect label: {}".format(select_idx.shape))

        # If the hypnogram extends past the recording, drop the trailing
        # labels that have no corresponding samples.
        if len(label_idx) > len(select_idx):
            print("before remove extra labels: {}, {}".format(
                select_idx.shape, labels.shape))
            extra_idx = np.setdiff1d(label_idx, select_idx)
            # Only trim when the surplus is a pure tail of the recording.
            if np.all(extra_idx > select_idx[-1]):
                n_label_trims = int(math.ceil(
                    len(extra_idx) / (EPOCH_SEC_SIZE * sampling_rate)))
                if n_label_trims != 0:
                    labels = labels[:-n_label_trims]
            print("after remove extra labels: {}, {}".format(
                select_idx.shape, labels.shape))

        # Extract the selected samples from the channel data.
        raw_ch = raw_ch_df.values[select_idx]

        # The remaining signal must split evenly into 30-s epochs.
        if len(raw_ch) % (EPOCH_SEC_SIZE * sampling_rate) != 0:
            raise ValueError(
                "Signal length {} is not a multiple of the epoch size".format(
                    len(raw_ch)))
        # BUG FIX: use floor division — '/' yields a float in Python 3 and
        # np.split rejects a non-integer number of sections.
        n_epochs = len(raw_ch) // int(EPOCH_SEC_SIZE * sampling_rate)

        # Epoch the signal and pair each epoch with its label.
        x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
        y = labels.astype(np.int32)
        if len(x) != len(y):
            raise ValueError("Epoch/label count mismatch: {} vs {}".format(
                len(x), len(y)))

        # Keep only the sleep period plus 30 minutes of wake on each side
        # (w_edge_mins minutes * 2 epochs per minute).
        w_edge_mins = 30
        nw_idx = np.where(y != stage_dict["W"])[0]
        start_idx = max(nw_idx[0] - (w_edge_mins * 2), 0)
        end_idx = min(nw_idx[-1] + (w_edge_mins * 2), len(y) - 1)
        select_idx = np.arange(start_idx, end_idx + 1)
        print("Data before selection: {}, {}".format(x.shape, y.shape))
        x = x[select_idx]
        y = y[select_idx]
        print("Data after selection: {}, {}".format(x.shape, y.shape))

        # Save one .npz per subject, named after the PSG file.
        filename = ntpath.basename(psg_fname).replace("-PSG.edf", ".npz")
        save_dict = {
            "x": x,
            "y": y,
            "fs": sampling_rate,
            "ch_label": select_ch,
            "header_raw": h_raw,
            "header_annotation": h_ann,
        }
        np.savez(os.path.join(args.output_dir, filename), **save_dict)

        print("\n=======================================\n")


if __name__ == "__main__":
    main()