a b/Download_Raw_EEG_Data/Extract-Raw-Data-Into-Matlab-Files.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
'''
Please read this before you run this file!

NOTICE:
    Be advised that this "Extract-Raw-Data-Into-Matlab-Files.py" Python file
    should only be executed in a Python 2 environment.
    I highly recommend running it under Python 2.7, where it has been tested.
    If you run this file under Python 3, I'm afraid it will raise an error
    and the generated labels will be wrong.

    If you have any questions, please let me know.
    My email is shuyuej@ieee.org.
    Thanks a lot.
'''
21
import numpy as np
22
import io
23
import os
24
import pyedflib
25
import scipy.io as sio
26
from scipy.signal import butter, filtfilt, iirnotch, lfilter
27
28
# Motor-imagery window in samples at 160 Hz: it starts 1 s after the trial
# begins and ends at 5 s, i.e. it lasts 4 seconds.
MOVEMENT_START = 160 * 1
MOVEMENT_END = 160 * 5
NOISE_LEVEL = 0.01

# Electrode names of the 64 PhysioNet EEG channels, in channel order.
_PHYSIONET_CHANNEL_NAMES = (
    "FC5", "FC3", "FC1", "FCz", "FC2", "FC4",
    "FC6", "C5", "C3", "C1", "Cz", "C2",
    "C4", "C6", "CP5", "CP3", "CP1", "CPz",
    "CP2", "CP4", "CP6", "Fp1", "Fpz", "Fp2",
    "AF7", "AF3", "AFz", "AF4", "AF8", "F7",
    "F5", "F3", "F1", "Fz", "F2", "F4",
    "F6", "F8", "FT7", "FT8", "T7", "T8",
    "T9", "T10", "TP7", "TP8", "P7", "P5",
    "P3", "P1", "Pz", "P2", "P4", "P6",
    "P8", "PO7", "PO3", "POz", "PO4", "PO8",
    "O1", "Oz", "O2", "Iz",
)

# 1-based channel index -> electrode name (extended 10-20 system).
PHYSIONET_ELECTRODES = dict(enumerate(_PHYSIONET_CHANNEL_NAMES, 1))
44
45
46
def load_edf_signals(path):
    """
    Read all signal channels and the annotations of an EDF file.

    path: path of the EDF file to read.
    returns (signals, annotations)
        signals: array of shape (N_samples, N_channels). The sample count is
            taken from channel 0, so all channels are assumed to be equally
            long.
        annotations: the list returned by pyedflib's ``read_annotation()``;
            callers use each entry as [onset, duration, event code] and
            multiply the onset by 1e-7 to obtain seconds.

    The reader is closed on every exit path, including errors and
    KeyboardInterrupt, so no file handle or buffer is leaked.
    """
    sig = pyedflib.EdfReader(path)
    try:
        n = sig.signals_in_file
        sigbuf = np.zeros((n, sig.getNSamples()[0]))
        for j in range(n):
            sigbuf[j, :] = sig.readSignal(j)
        # entries: [onset (scaled by 1e-7 to seconds downstream), duration,
        #           event code T0/T1/T2]
        annotations = sig.read_annotation()
    finally:
        # prevent memory leaks and access problems of unclosed buffers,
        # whatever the reason for leaving the block
        sig._close()
    return sigbuf.transpose(), annotations
65
66
67
def get_physionet_electrode_positions():
    """
    Return the reference coordinates of the 64 PhysioNet channels.

    Row k-1 of the returned array holds the coordinates that
    get_electrode_positions() lists for the electrode named
    PHYSIONET_ELECTRODES[k] (k = 1..64), so rows follow channel order.
    """
    reference = get_electrode_positions()
    coords = []
    for channel in range(1, 65):
        name = PHYSIONET_ELECTRODES[channel]
        coords.append(reference[name])
    return np.array(coords)
70
71
72
def projection_2d(loc):
    """
    Azimuthal equidistant projection (AEP) of 3D cartesian points to 2D.

    loc: N x 3 array of 3D points (only the first three columns are used).
    returns: N x 2 array of projected points. The direction in the plane is
    the point's azimuth, and the distance from the 2D origin equals the angle
    (in radians) between the point and the +z axis.
    """
    xs = loc[:, 0]
    ys = loc[:, 1]
    zs = loc[:, 2]
    azimuth = np.arctan2(ys, xs)
    elevation = np.arctan2(zs, np.hypot(xs, ys))
    radius = np.pi / 2 - elevation
    return np.column_stack((radius * np.cos(azimuth), radius * np.sin(azimuth)))
83
84
85
def get_electrode_positions(path="electrode_positions.txt"):
    """
    Return a dictionary (Name) -> (x, y, z) read from a positions file.

    Each non-blank line of the file holds an electrode name (extended 10-20
    system) followed by its whitespace-separated coordinates, documented as
    cartesian coordinates on the unit sphere.

    path: file to read; defaults to "electrode_positions.txt" relative to the
        current working directory (the previously hard-coded location).
    returns: dict mapping electrode name to a tuple of floats.

    Blank lines (e.g. a trailing empty line) are skipped instead of raising
    IndexError.
    """
    positions = dict()
    with io.open(path, "r", encoding="utf-8") as pos_file:
        for line in pos_file:
            parts = line.split()
            if not parts:
                continue
            positions[parts[0]] = tuple(float(part) for part in parts[1:])
    return positions
96
97
98
def _annotation_action(code):
    """
    Return the digit of a PhysioNet event code such as "T0"/"T1"/"T2" as int.

    Under Python 3 pyedflib may hand back the code as ``bytes``; indexing bytes
    yields an integer code point (48 for "0"), which silently produced wrong
    labels, so the code is decoded to text first.
    """
    if isinstance(code, bytes):
        code = code.decode("utf-8")
    return int(code[1])


def load_physionet_data(subject_id, num_classes=4, long_edge=False):
    """
    Load the motor-imagery (MI) trials of one PhysioNet subject.

    subject_id: ID (1-109) of the subject. Files are read from the current
        working directory as 'S%03dR%02d.edf' (subject, run).
    num_classes: number of classes (2, 3 or 4).
        Label columns: 0 = T1 and 1 = T2 of runs 4/8/12 (left/right fist);
        with 4 classes, column 2 = T2 of runs 6/10/14 (feet); with 3 or 4
        classes the last column is baseline (random windows of run 1).
    long_edge: if False each trial is 4 s long, if True 6 s. In both cases
        the window is centred on the MI period, i.e. the span between an MI
        marker and the following T0 marker.
    returns (X, y, pos, fs)
        X: trials with shape (N_trials, N_samples, N_channels)
        y: one-hot labels with shape (N_trials, N_classes)
        pos: 2D projected electrode positions
        fs: sample rate
    With 2 classes only the MI trials found are returned; with 3 or 4 classes
    exactly 21 * num_classes trials are returned, padded with baseline.

    raises ValueError if the runs contain more MI trials than there are slots.
    """

    SAMPLE_RATE = 160
    EEG_CHANNELS = 64
    BASELINE_RUN = 1

    MI_RUNS = [4, 8, 12]  # l/r fist
    if num_classes >= 4:
        MI_RUNS += [6, 10, 14]  # fists/feet runs; only the feet events are kept

    # total number of samples per long run
    RUN_LENGTH = 125 * SAMPLE_RATE

    # length of a single trial in seconds / samples
    TRIAL_LENGTH = 4 if not long_edge else 6
    TRIAL_SAMPLES = TRIAL_LENGTH * SAMPLE_RATE
    NUM_TRIALS = 21 * num_classes

    X = np.zeros((len(MI_RUNS), RUN_LENGTH, EEG_CHANNELS))
    events = []

    base_path = 'S%03dR%02d.edf'

    for i_run, current_run in enumerate(MI_RUNS):
        # load from file
        signals, annotations = load_edf_signals(base_path % (subject_id, current_run))
        X[i_run, :signals.shape[0], :] = signals

        # runs 6, 10, 14: T1 = both fists (dropped), T2 = feet (class 3)
        is_feet_run = (current_run - 6) % 4 == 0

        current_event = [i_run, 0, 0, 0]  # run, class, start, end
        for annotation in annotations:
            t = int(annotation[0] * SAMPLE_RATE * 1e-7)
            action = _annotation_action(annotation[2])

            if action == 0 and current_event[1] != 0:
                # A rest marker closes the open MI event: pad (or trim) it
                # symmetrically to exactly TRIAL_SAMPLES samples. Floor
                # division keeps indices integral under Python 3 and matches
                # the Python 2 `/` behaviour; `duration % 2` absorbs the odd
                # sample so that end - start == TRIAL_SAMPLES.
                duration = t - current_event[2]
                pad = (TRIAL_SAMPLES - duration) // 2
                current_event[2] -= pad + duration % 2
                current_event[3] = t + pad

                if not is_feet_run or current_event[1] == 2:
                    if is_feet_run:
                        current_event[1] = 3
                    events.append(current_event)
                # start over so a repeated rest marker cannot modify or
                # re-append an event that was already handled
                current_event = [i_run, 0, 0, 0]

            elif action > 0:
                current_event = [i_run, action, t, 0]

    # split runs into trials
    num_mi_trials = len(events)
    if num_mi_trials > NUM_TRIALS:
        raise ValueError("subject %s yielded %d MI trials but only %d slots exist"
                         % (subject_id, num_mi_trials, NUM_TRIALS))

    trials = np.zeros((NUM_TRIALS, TRIAL_SAMPLES, EEG_CHANNELS))
    labels = np.zeros((NUM_TRIALS, num_classes))

    for i, (run, cls, start, end) in enumerate(events):
        trials[i, :, :] = X[run, start:end]
        labels[i, cls - 1] = 1.

    if num_classes < 3:
        return (trials[:num_mi_trials, ...],
                labels[:num_mi_trials, ...],
                projection_2d(get_physionet_electrode_positions()),
                SAMPLE_RATE)

    # Fill the remaining slots with random windows of the baseline run,
    # labelled with the last class column.
    signals, _ = load_edf_signals(base_path % (subject_id, BASELINE_RUN))
    for i in range(num_mi_trials, NUM_TRIALS):
        offset = np.random.randint(0, signals.shape[0] - TRIAL_SAMPLES)
        trials[i, :, :] = signals[offset:offset + TRIAL_SAMPLES, :]
        labels[i, -1] = 1.
    return trials, labels, projection_2d(get_physionet_electrode_positions()), SAMPLE_RATE
184
185
186
def load_raw_data(electrodes, subject=None, num_classes=4, long_edge=False):
    """
    Load the trials of several subjects, keeping only the given electrodes.

    electrodes: channel indices (last axis of each subject's trial array)
        to keep, e.g. [i] for a single electrode.
    subject: None for subjects 11-14, a single ID (anything int() accepts),
        or an iterable of IDs.
    num_classes, long_edge: forwarded to load_physionet_data.
    returns (X, y)
        X: float64 array (N_subjects, N_trials, N_samples, N_electrodes, 1)
        y: float64 array (N_subjects, N_trials, N_classes)

    Subjects whose data fail to load are skipped with a printed message
    (best effort: individual recordings may be missing or irregular). With
    two classes, subjects that do not yield exactly 42 trials are dropped.
    KeyboardInterrupt/SystemExit are no longer swallowed.

    raises ValueError if no subject could be loaded.
    """
    trials = []
    labels = []

    if subject is None:
        subject_ids = range(11, 15)
    else:
        try:
            subject_ids = [int(subject)]
        except (TypeError, ValueError):
            # not a single ID: treat it as an iterable of IDs
            subject_ids = subject

    for subject_id in subject_ids:
        try:
            t, l, _, _ = load_physionet_data(subject_id, num_classes, long_edge=long_edge)
        except Exception as err:
            print("Skipping subject %s: %s" % (subject_id, err))
            continue
        if num_classes == 2 and t.shape[0] != 42:
            # drop subjects with fewer trials
            continue
        trials.append(t[:, :, electrodes])
        labels.append(l)

    if not trials:
        raise ValueError("No subject could be loaded from %r" % (subject_ids,))

    X = np.array(trials, dtype=np.float64)
    X = X.reshape((len(trials),) + trials[0].shape + (1,))
    return X, np.array(labels, dtype=np.float64)
211
212
213
def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a Butterworth band-pass filter.

    lowcut, highcut: pass-band edges in Hz.
    fs: sampling rate in Hz; the edges are normalised by the Nyquist
        frequency (fs / 2) before being passed to scipy.signal.butter.
    order: filter order.
    returns (b, a): numerator and denominator coefficients.
    """
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype='band')
219
220
221
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """
    Apply a zero-phase Butterworth band-pass filter to *data*.

    The filter (edges lowcut/highcut in Hz, normalised by fs / 2) is run
    forward and backward with scipy.signal.filtfilt along the last axis.
    returns: the filtered array, same shape as data.
    """
    nyquist = 0.5 * fs
    coeffs = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')
    return filtfilt(coeffs[0], coeffs[1], data)
225
226
227
# Sample rate and desired cutoff frequencies (in Hz), used by the optional
# filtering step of save_electrode_datasets.
fs = 160.0
lowcut = 5.0
highcut = 30.0

# Save dataset of 4 classes
nclasses = 4


def save_electrode_datasets(save_dir, subjects, num_classes=nclasses, apply_filters=False):
    """
    Extract each of the 64 electrodes separately and save it to .mat files.

    For every electrode index i (0-63) this writes, inside save_dir:
        Dataset_<i+1>.mat  with variable 'Dataset' (squeezed trial array)
        Labels_<i+1>.mat   with variable 'Labels'  (one-hot labels)

    save_dir: output folder; it is created if it does not exist.
    subjects: subject IDs passed to load_raw_data.
    num_classes: number of classes passed to load_raw_data.
    apply_filters: if True, apply a 60 Hz notch filter (Q = 30) followed by a
        4th-order Butterworth band-pass (lowcut-highcut Hz) along the last
        axis before saving.
    """
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    for i in range(64):
        X, Y = load_raw_data(electrodes=[i], subject=subjects, num_classes=num_classes)
        X = np.squeeze(X)

        if apply_filters:
            # w0 is given as a fraction of the Nyquist frequency so this also
            # works with SciPy versions whose iirnotch has no `fs` argument.
            b, a = iirnotch(60.0 / (0.5 * fs), 30.0)
            X = lfilter(b, a, X)
            X = butter_bandpass_filter(X, lowcut, highcut, fs, order=4)

        # os.path.join puts the files inside save_dir rather than next to it
        sio.savemat(os.path.join(save_dir, 'Dataset_%d.mat' % (i + 1)), {'Dataset': X})
        sio.savemat(os.path.join(save_dir, 'Labels_%d.mat' % (i + 1)), {'Labels': Y})

        print("Finished saving %d electrodes" % (i + 1))


if __name__ == "__main__":
    print("Start to save the Files!")

    # Save the 20 subjects' dataset (raw, unfiltered signals)
    save_electrode_datasets('20-Subjects', range(1, 21))

    # Larger variant (optionally notch + band-pass filtered):
    # save_electrode_datasets('100-Subjects', range(1, 102), apply_filters=True)