deepsleepnet_data/prepare_physionet.py
# Source: https://github.com/akaraspt/deepsleepnet

import argparse
import glob
import math
import ntpath
import os
import shutil
import urllib
import urllib2

from datetime import datetime

import numpy as np
import pandas as pd

from mne import Epochs, pick_types, find_events
from mne.io import concatenate_raws, read_raw_edf

import dhedfreader

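# Example usage (illustrative paths; assumes the Sleep-EDF "*-PSG.edf" and
# "*-Hypnogram.edf" files from PhysioNet have already been downloaded into
# --data_dir):
#
#   python prepare_physionet.py \
#       --data_dir /data/physionet_sleep \
#       --output_dir /data/physionet_sleep/eeg_fpz_cz \
#       --select_ch "EEG Fpz-Cz"
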
# Label values
W = 0
N1 = 1
N2 = 2
N3 = 3
REM = 4
UNKNOWN = 5

stage_dict = {
    "W": W,
    "N1": N1,
    "N2": N2,
    "N3": N3,
    "REM": REM,
    "UNKNOWN": UNKNOWN
}

class_dict = {
    0: "W",
    1: "N1",
    2: "N2",
    3: "N3",
    4: "REM",
    5: "UNKNOWN"
}

# Map hypnogram annotation strings to label values. Stages 3 and 4 are merged
# into N3; movement time and unscored epochs are treated as UNKNOWN.
ann2label = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
    "Sleep stage ?": 5,
    "Movement time": 5
}

# Length of a sleep scoring epoch in seconds
EPOCH_SEC_SIZE = 30

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="/data/physionet_sleep",
                        help="Directory containing the PhysioNet Sleep-EDF *-PSG.edf and *-Hypnogram.edf files.")
    parser.add_argument("--output_dir", type=str, default="/data/physionet_sleep/eeg_fpz_cz",
                        help="Directory where to save outputs.")
    parser.add_argument("--select_ch", type=str, default="EEG Fpz-Cz",
                        help="Name of the channel to extract, e.g. 'EEG Fpz-Cz'.")
    args = parser.parse_args()

    # Output dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        shutil.rmtree(args.output_dir)
        os.makedirs(args.output_dir)

    # Select channel
    select_ch = args.select_ch

    # Read raw and annotation EDF files
    psg_fnames = glob.glob(os.path.join(args.data_dir, "*PSG.edf"))
    ann_fnames = glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf"))
    psg_fnames.sort()
    ann_fnames.sort()
    psg_fnames = np.asarray(psg_fnames)
    ann_fnames = np.asarray(ann_fnames)

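    # After sorting, psg_fnames[i] and ann_fnames[i] are assumed to belong to
    # the same recording (the PSG and its hypnogram share a subject/night prefix).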
    for i in range(len(psg_fnames)):
        # if not "ST7171J0-PSG.edf" in psg_fnames[i]:
        #     continue

        raw = read_raw_edf(psg_fnames[i], preload=True, stim_channel=None)
        sampling_rate = raw.info['sfreq']
        raw_ch_df = raw.to_data_frame(scaling_time=100.0)[select_ch]
        raw_ch_df = raw_ch_df.to_frame()
        raw_ch_df = raw_ch_df.set_index(np.arange(len(raw_ch_df)))

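        # The EDF headers are parsed directly (via dhedfreader) to get the
        # recording start datetimes, which are later used to check that the
        # PSG and hypnogram files are aligned.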
        # Get raw header
        f = open(psg_fnames[i], 'r')
        reader_raw = dhedfreader.BaseEDFReader(f)
        reader_raw.read_header()
        h_raw = reader_raw.header
        f.close()
        raw_start_dt = datetime.strptime(h_raw['date_time'], "%Y-%m-%d %H:%M:%S")

        # Read annotation and its header
        f = open(ann_fnames[i], 'r')
        reader_ann = dhedfreader.BaseEDFReader(f)
        reader_ann.read_header()
        h_ann = reader_ann.header
        _, _, ann = zip(*reader_ann.records())
        f.close()
        ann_start_dt = datetime.strptime(h_ann['date_time'], "%Y-%m-%d %H:%M:%S")

        # Assert that raw and annotation files start at the same time
        assert raw_start_dt == ann_start_dt

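        # Each hypnogram record is an (onset_sec, duration_sec, annotation)
        # triple. Non-UNKNOWN annotations are expanded into one label per 30-s
        # epoch plus the sample indices they cover; UNKNOWN spans are collected
        # for removal.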
        # Generate label and remove indices
        remove_idx = []    # sample indices that will be removed (UNKNOWN stages)
        labels = []        # per-epoch labels
        label_idx = []     # sample indices that have labels
        for a in ann[0]:
            onset_sec, duration_sec, ann_char = a
            ann_str = "".join(ann_char)
            label = ann2label[ann_str]
            if label != UNKNOWN:
                if duration_sec % EPOCH_SEC_SIZE != 0:
                    raise Exception("Annotation duration is not a multiple of the 30-s epoch size")
                duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
                label_epoch = np.ones(duration_epoch, dtype=np.int) * label
                labels.append(label_epoch)
                idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=np.int)
                label_idx.append(idx)

                print "Include onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str
                )
            else:
                idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=np.int)
                remove_idx.append(idx)

                print "Remove onset:{}, duration:{}, label:{} ({})".format(
                    onset_sec, duration_sec, label, ann_str
                )
        labels = np.hstack(labels)

        print "before remove unwanted: {}".format(np.arange(len(raw_ch_df)).shape)
        if len(remove_idx) > 0:
            remove_idx = np.hstack(remove_idx)
            select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx)
        else:
            select_idx = np.arange(len(raw_ch_df))
        print "after remove unwanted: {}".format(select_idx.shape)

        # Select only the data with labels
        print "before intersect label: {}".format(select_idx.shape)
        label_idx = np.hstack(label_idx)
        select_idx = np.intersect1d(select_idx, label_idx)
        print "after intersect label: {}".format(select_idx.shape)

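        # If the hypnogram labels extend past the end of the recording, trim the
        # leftover samples at the tail (and the corresponding label epoch) so
        # that data and labels stay aligned on 30-s epoch boundaries.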
        # Remove extra index
        if len(label_idx) > len(select_idx):
            print "before remove extra labels: {}, {}".format(select_idx.shape, labels.shape)
            extra_idx = np.setdiff1d(label_idx, select_idx)
            # Trim the tail
            if np.all(extra_idx > select_idx[-1]):
                n_trims = len(select_idx) % int(EPOCH_SEC_SIZE * sampling_rate)
                n_label_trims = int(math.ceil(n_trims / (EPOCH_SEC_SIZE * sampling_rate)))
                select_idx = select_idx[:-n_trims]
                labels = labels[:-n_label_trims]
            print "after remove extra labels: {}, {}".format(select_idx.shape, labels.shape)

        # Remove movement and unknown stages if any
        raw_ch = raw_ch_df.values[select_idx]

        # Verify that we can split into 30-s epochs
        if len(raw_ch) % (EPOCH_SEC_SIZE * sampling_rate) != 0:
            raise Exception("Signal length is not a multiple of the 30-s epoch size")
        n_epochs = int(len(raw_ch) / (EPOCH_SEC_SIZE * sampling_rate))

        # Get epochs and their corresponding labels
        x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
        y = labels.astype(np.int32)

        assert len(x) == len(y)

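        # Keep only epochs from 30 minutes before the first non-wake epoch to
        # 30 minutes after the last one; with 30-s epochs, w_edge_mins * 2
        # converts minutes into a number of epochs.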
        # Select on sleep periods
        w_edge_mins = 30
        nw_idx = np.where(y != stage_dict["W"])[0]
        start_idx = nw_idx[0] - (w_edge_mins * 2)
        end_idx = nw_idx[-1] + (w_edge_mins * 2)
        if start_idx < 0: start_idx = 0
        if end_idx >= len(y): end_idx = len(y) - 1
        select_idx = np.arange(start_idx, end_idx+1)
        print("Data before selection: {}, {}".format(x.shape, y.shape))
        x = x[select_idx]
        y = y[select_idx]
        print("Data after selection: {}, {}".format(x.shape, y.shape))

        # Save
        filename = ntpath.basename(psg_fnames[i]).replace("-PSG.edf", ".npz")
        save_dict = {
            "x": x,
            "y": y,
            "fs": sampling_rate,
            "ch_label": select_ch,
            "header_raw": h_raw,
            "header_annotation": h_ann,
        }
        np.savez(os.path.join(args.output_dir, filename), **save_dict)

        print "\n=======================================\n"


if __name__ == "__main__":
    main()
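
# Each output .npz file holds one recording and can be loaded back with numpy,
# e.g. (the file name below is illustrative):
#
#   data = np.load("SC4001E0.npz")
#   x, y, fs = data["x"], data["y"], data["fs"]   # 30-s epochs, labels, sampling rate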