|
a |
|
b/deepsleepnet_data/prepare_physionet.py |
|
|
1 |
#Source : https://github.com/akaraspt/deepsleepnet |
|
|
2 |
|
|
|
3 |
import argparse |
|
|
4 |
import glob |
|
|
5 |
import math |
|
|
6 |
import ntpath |
|
|
7 |
import os |
|
|
8 |
import shutil |
|
|
9 |
import urllib |
|
|
10 |
import urllib2 |
|
|
11 |
|
|
|
12 |
from datetime import datetime |
|
|
13 |
|
|
|
14 |
import numpy as np |
|
|
15 |
import pandas as pd |
|
|
16 |
|
|
|
17 |
from mne import Epochs, pick_types, find_events |
|
|
18 |
from mne.io import concatenate_raws, read_raw_edf |
|
|
19 |
|
|
|
20 |
import dhedfreader |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
# Integer codes for the five AASM sleep stages plus an UNKNOWN bucket.
W = 0
N1 = 1
N2 = 2
N3 = 3
REM = 4
UNKNOWN = 5

# Stage name -> integer label.
stage_dict = {
    "W": W,
    "N1": N1,
    "N2": N2,
    "N3": N3,
    "REM": REM,
    "UNKNOWN": UNKNOWN
}

# Integer label -> stage name (exact inverse of stage_dict).
class_dict = {label: name for name, label in stage_dict.items()}

# EDF hypnogram annotation string -> integer label.
# R&K stages 3 and 4 are both mapped to N3 (AASM merges them);
# '?' and movement epochs both map to UNKNOWN.
ann2label = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
    "Sleep stage ?": 5,
    "Movement time": 5
}

# Length of one scoring epoch in seconds.
EPOCH_SEC_SIZE = 30
|
|
61 |
|
|
|
62 |
|
|
|
63 |
def main():
    """Convert Sleep-EDF PSG/Hypnogram EDF pairs into per-subject .npz files.

    For every ``*PSG.edf`` / ``*Hypnogram.edf`` pair found in ``--data_dir``
    (paired by sorted order), this extracts the channel named by
    ``--select_ch``, removes samples annotated as UNKNOWN/movement, segments
    the remaining signal into 30-s epochs with one stage label each, trims
    wake periods to at most 30 minutes on each side of the sleep period, and
    saves ``x`` (epochs), ``y`` (labels) and header metadata to
    ``--output_dir``.

    Raises:
        Exception: if an annotation duration or the trimmed signal length is
            not a multiple of the 30-s epoch size.
        AssertionError: if a PSG/hypnogram pair disagrees on start time, or
            epoch and label counts diverge.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/data/physionet_sleep",
                        help="Directory containing the raw PSG and Hypnogram EDF files.")
    parser.add_argument("--output_dir", type=str, default="/data/physionet_sleep/eeg_fpz_cz",
                        help="Directory where to save outputs.")
    parser.add_argument("--select_ch", type=str, default="EEG Fpz-Cz",
                        help="Name of the EEG channel to extract from each recording.")
    args = parser.parse_args()

    # Recreate the output dir from scratch so stale files from a previous
    # run never mix with the new outputs.
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)

    # Select channel
    select_ch = args.select_ch

    # Pair PSG recordings with their hypnogram files by sorted filename order.
    psg_fnames = glob.glob(os.path.join(args.data_dir, "*PSG.edf"))
    ann_fnames = glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf"))
    psg_fnames.sort()
    ann_fnames.sort()
    psg_fnames = np.asarray(psg_fnames)
    ann_fnames = np.asarray(ann_fnames)

    for i in range(len(psg_fnames)):
        raw = read_raw_edf(psg_fnames[i], preload=True, stim_channel=None)
        sampling_rate = raw.info['sfreq']
        raw_ch_df = raw.to_data_frame(scaling_time=100.0)[select_ch]
        raw_ch_df = raw_ch_df.to_frame()
        # BUGFIX: set_index returns a new frame; the original discarded it.
        raw_ch_df = raw_ch_df.set_index(np.arange(len(raw_ch_df)))

        # Get raw header. EDF files are binary, so open in 'rb' (the original
        # used text mode); 'with' guarantees the handle is closed on error.
        with open(psg_fnames[i], 'rb') as f:
            reader_raw = dhedfreader.BaseEDFReader(f)
            reader_raw.read_header()
            h_raw = reader_raw.header
        raw_start_dt = datetime.strptime(h_raw['date_time'], "%Y-%m-%d %H:%M:%S")

        # Read annotation records and header from the hypnogram file.
        with open(ann_fnames[i], 'rb') as f:
            reader_ann = dhedfreader.BaseEDFReader(f)
            reader_ann.read_header()
            h_ann = reader_ann.header
            _, _, ann = zip(*reader_ann.records())
        ann_start_dt = datetime.strptime(h_ann['date_time'], "%Y-%m-%d %H:%M:%S")

        # Assert that raw and annotation files start at the same time
        assert raw_start_dt == ann_start_dt

        # Generate per-epoch labels and collect sample indices to remove.
        remove_idx = []    # sample indices of UNKNOWN/movement periods
        labels = []        # per-epoch stage labels
        label_idx = []     # sample indices covered by usable annotations
        for a in ann[0]:
            onset_sec, duration_sec, ann_char = a
            ann_str = "".join(ann_char)
            label = ann2label[ann_str]
            # Sample indices spanned by this annotation.
            idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=int)
            if label != UNKNOWN:
                if duration_sec % EPOCH_SEC_SIZE != 0:
                    raise Exception(
                        "Annotation duration {}s is not a multiple of the "
                        "{}-s epoch size".format(duration_sec, EPOCH_SEC_SIZE))
                duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
                # np.int was removed from NumPy; the builtin int is equivalent.
                label_epoch = np.ones(duration_epoch, dtype=int) * label
                labels.append(label_epoch)
                label_idx.append(idx)

                print("Include onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str
                ))
            else:
                remove_idx.append(idx)

                print("Remove onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str
                ))
        labels = np.hstack(labels)

        print("before remove unwanted: {}".format(np.arange(len(raw_ch_df)).shape))
        if len(remove_idx) > 0:
            remove_idx = np.hstack(remove_idx)
            select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx)
        else:
            select_idx = np.arange(len(raw_ch_df))
        print("after remove unwanted: {}".format(select_idx.shape))

        # Select only the data with labels
        print("before intersect label: {}".format(select_idx.shape))
        label_idx = np.hstack(label_idx)
        select_idx = np.intersect1d(select_idx, label_idx)
        print("after intersect label: {}".format(select_idx.shape))

        # The hypnogram may extend past the end of the PSG signal; if all the
        # extra annotated samples are beyond the signal, trim the tail.
        if len(label_idx) > len(select_idx):
            print("before remove extra labels: {}, {}".format(select_idx.shape, labels.shape))
            extra_idx = np.setdiff1d(label_idx, select_idx)
            # Trim the tail
            if np.all(extra_idx > select_idx[-1]):
                n_trims = len(select_idx) % int(EPOCH_SEC_SIZE * sampling_rate)
                n_label_trims = int(math.ceil(n_trims / (EPOCH_SEC_SIZE * sampling_rate)))
                # BUGFIX: guard against n_trims == 0 -- the original
                # select_idx[:-0] slice would wrongly produce an empty array.
                if n_trims > 0:
                    select_idx = select_idx[:-n_trims]
                if n_label_trims > 0:
                    labels = labels[:-n_label_trims]
            print("after remove extra labels: {}, {}".format(select_idx.shape, labels.shape))

        # Remove movement and unknown stages if any
        raw_ch = raw_ch_df.values[select_idx]

        # Verify that we can split into 30-s epochs
        if len(raw_ch) % (EPOCH_SEC_SIZE * sampling_rate) != 0:
            raise Exception(
                "Signal length {} is not a multiple of the {}-s epoch "
                "size".format(len(raw_ch), EPOCH_SEC_SIZE))
        # int() cast: true division yields a float, which np.split rejects.
        n_epochs = int(len(raw_ch) / (EPOCH_SEC_SIZE * sampling_rate))

        # Get epochs and their corresponding labels
        x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
        y = labels.astype(np.int32)

        assert len(x) == len(y)

        # Keep the sleep period plus at most 30 min of wake on each side
        # (epochs are 30 s, so 2 epochs per minute).
        w_edge_mins = 30
        nw_idx = np.where(y != stage_dict["W"])[0]
        start_idx = nw_idx[0] - (w_edge_mins * 2)
        end_idx = nw_idx[-1] + (w_edge_mins * 2)
        if start_idx < 0:
            start_idx = 0
        if end_idx >= len(y):
            end_idx = len(y) - 1
        select_idx = np.arange(start_idx, end_idx + 1)
        print("Data before selection: {}, {}".format(x.shape, y.shape))
        x = x[select_idx]
        y = y[select_idx]
        print("Data after selection: {}, {}".format(x.shape, y.shape))

        # Save one .npz per subject, named after the PSG file.
        filename = ntpath.basename(psg_fnames[i]).replace("-PSG.edf", ".npz")
        save_dict = {
            "x": x,
            "y": y,
            "fs": sampling_rate,
            "ch_label": select_ch,
            "header_raw": h_raw,
            "header_annotation": h_ann,
        }
        np.savez(os.path.join(args.output_dir, filename), **save_dict)

        print("\n=======================================\n")
|
|
216 |
|
|
|
217 |
|
|
|
218 |
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()