a b/preprocess.py
1
""" 
2
Copyright (C) 2022 King Saud University, Saudi Arabia 
3
SPDX-License-Identifier: Apache-2.0 
4
5
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
6
this file except in compliance with the License. You may obtain a copy of the 
7
License at
8
9
http://www.apache.org/licenses/LICENSE-2.0  
10
11
Unless required by applicable law or agreed to in writing, software distributed
12
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 
13
CONDITIONS OF ANY KIND, either express or implied. See the License for the
14
specific language governing permissions and limitations under the License. 
15
16
Author:  Hamdi Altaheri 
17
"""
18
19
# Dataset BCI Competition IV-2a is available at 
20
# http://bnci-horizon-2020.eu/database/data-sets
21
22
import numpy as np
23
import scipy.io as sio
24
from tensorflow.keras.utils import to_categorical
25
from sklearn.preprocessing import StandardScaler
26
from sklearn.utils import shuffle
27
28
# We need the following function to load and preprocess the High Gamma Dataset
29
# from preprocess_HGD import load_HGD_data
30
31
#%%
32
def load_data_LOSO(data_path, subject, dataset, n_subjects=9):
    """Load and split a dataset using 'Leave One Subject Out' (LOSO).

    LOSO is used for subject-independent evaluation: the model is trained
    and evaluated over several folds, one per subject. In each fold one
    subject is held out for evaluation and all the others are used for
    training, so the evaluation subject is never seen during training.

    For every subject both sessions (training and evaluation files) are
    loaded and concatenated along the trial axis before the split.

        Parameters
        ----------
        data_path: string
            dataset root path (must end with a path separator); the
            sub-folder 's<k>/' is appended for subject k = 1..n_subjects.
            # Dataset BCI Competition IV-2a is available at
            # http://bnci-horizon-2020.eu/database/data-sets
        subject: int
            ZERO-based index of the held-out subject, in
            [0, n_subjects - 1]. It is compared against the zero-based
            loop index, so subject 0 is the data stored under 's1/'.
        dataset: string
            'BCI2a' or 'CS2R'.
        n_subjects: int
            number of subjects in the dataset (default: 9).

        Returns
        -------
        X_train, y_train:
            trials and labels of all subjects except the held-out one.
        X_test, y_test:
            trials and labels of the held-out subject.

        Raises
        ------
        ValueError
            if `dataset` is not supported, if `subject` is out of range,
            or if `n_subjects` is smaller than 2.
    """
    # Fail fast: previously an unknown dataset or an out-of-range subject
    # surfaced as an unbound-variable error (the latter only after every
    # file had already been loaded).
    if dataset not in ('BCI2a', 'CS2R'):
        raise ValueError("'{}' dataset is not supported yet!".format(dataset))
    if n_subjects < 2:
        raise ValueError("LOSO needs at least 2 subjects, got {}".format(n_subjects))
    if not 0 <= subject < n_subjects:
        raise ValueError("subject must be in [0, {}], got {}".format(n_subjects - 1, subject))

    train_X, train_y = [], []
    X_test, y_test = None, None
    for sub in range(n_subjects):
        path = data_path + 's' + str(sub + 1) + '/'

        if dataset == 'BCI2a':
            X1, y1 = load_BCI2a_data(path, sub + 1, True)
            X2, y2 = load_BCI2a_data(path, sub + 1, False)
        else:  # 'CS2R'
            # NOTE(review): the CS2R loader globs 'S_*/' below the path it
            # receives, while get_data() passes it the bare dataset root;
            # confirm that the per-subject 's<k>/' prefix is intended here.
            X1, y1, _, _, _ = load_CS2R_data_v2(path, sub, True)
            X2, y2, _, _, _ = load_CS2R_data_v2(path, sub, False)
        # elif (dataset == 'HGD'):
        #     X1, y1 = load_HGD_data(path, sub+1, True)
        #     X2, y2 = load_HGD_data(path, sub+1, False)

        # Merge both sessions of the current subject.
        X = np.concatenate((X1, X2), axis=0)
        y = np.concatenate((y1, y2), axis=0)

        if sub == subject:
            X_test, y_test = X, y
        else:
            train_X.append(X)
            train_y.append(y)

    # Concatenate once instead of re-copying the growing array per subject.
    X_train = np.concatenate(train_X, axis=0)
    y_train = np.concatenate(train_y, axis=0)

    return X_train, y_train, X_test, y_test
82
83
84
#%%
85
def load_BCI2a_data(data_path, subject, training, all_trials=True):
    """Load one session of one subject of the BCI Competition IV-2a dataset.

    This supports the subject-specific (subject-dependent) approach, which
    uses the same training and testing data as the original competition,
    i.e., 288 x 9 trials in session 1 for training and 288 x 9 trials in
    session 2 for testing.

    The file '<data_path>A0<subject>T.mat' (training) or
    '<data_path>A0<subject>E.mat' (testing) is read. For every trial a
    7-second window (7 * 250 samples) of the first 22 columns of the
    signal matrix is extracted starting at the trial start sample, and then
    cropped to the samples [1.5 s, 6 s) of that window.

        Parameters
        ----------
        data_path: string
            folder containing the .mat files (must end with a path
            separator, the file name is appended directly)
            # Dataset BCI Competition IV-2a is available on
            # http://bnci-horizon-2020.eu/database/data-sets
        subject: int
            number of subject in [1, .. ,9]
        training: bool
            if True, load training data
            if False, load testing data
        all_trials: bool
            if True, load all trials
            if False, ignore trials with artifacts

        Returns
        -------
        data_return: ndarray, shape (n_valid_trials, 22, 1125)
            extracted trial windows (channels x time points).
        class_return: ndarray of int, shape (n_valid_trials,)
            zero-based class labels (stored label minus 1).
    """

    # Define MI-trials parameters
    n_channels = 22
    n_tests = 6 * 48            # maximum number of trials per session
    window_length = 7 * 250     # samples extracted per trial before cropping

    # Define MI trial window
    fs = 250            # sampling rate
    t1 = int(1.5 * fs)  # start time_point
    t2 = int(6 * fs)    # end time_point

    class_return = np.zeros(n_tests)
    data_return = np.zeros((n_tests, n_channels, window_length))

    session_tag = 'T' if training else 'E'
    mat = sio.loadmat(data_path + 'A0' + str(subject) + session_tag + '.mat')
    runs = mat['data']   # indexed below as a (1, n_runs) array of run structs

    n_valid = 0
    for ii in range(runs.size):
        run = runs[0, ii][0, 0]
        # Fields are accessed by position: 0 = signals, 1 = trial start
        # samples, 2 = labels, 5 = artifact flags.
        signals = run[0]
        # Flatten the per-run vectors so each element is a true scalar;
        # calling int() on a 1-element array is deprecated in NumPy >= 1.25.
        trial_starts = np.ravel(run[1]).astype(int)
        labels = np.ravel(run[2])
        artifacts = np.ravel(run[5])

        for idx, start in enumerate(trial_starts):
            if artifacts[idx] != 0 and not all_trials:
                continue
            window = signals[start:start + window_length, :n_channels]
            data_return[n_valid, :, :] = np.transpose(window)
            class_return[n_valid] = int(labels[idx])
            n_valid += 1

    # Drop the unused pre-allocated rows and crop to the MI window.
    data_return = data_return[0:n_valid, :, t1:t2]
    class_return = class_return[0:n_valid]
    class_return = (class_return - 1).astype(int)

    return data_return, class_return
149
150
151
152
#%%
153
import json
154
from mne.io import read_raw_edf
155
from dateutil.parser import parse
156
import glob as glob
157
from datetime import datetime
158
159
def load_CS2R_data_v2(data_path, subject, training,
                      classes_labels=('Fingers', 'Wrist', 'Elbow', 'Rest'),
                      all_trials=True):
    """Load training/testing data of the CS2R motor imagery dataset
    for a specific subject.

    Subject folders are discovered with the pattern '<data_path>S_*/'.
    Session 1 is used for training and session 2 for testing. Up to 5 runs
    are read per session; each run needs an .edf recording and a .json
    marker file. Recordings are filtered (4-50 Hz), reduced to the channels
    at positions 4..35 and resampled to 250 Hz. For every marker, the 4.5 s
    ending 6 s after the marker start are extracted as one trial.

        Parameters
        ----------
        data_path: string
            dataset path (must end with a path separator)
        subject: int
            zero-based index into the sorted list of discovered subject
            folders
        training: bool
            if True, load training data (session 1)
            if False, load testing data (session 2)
        classes_labels: tuple
            classes of motor imagery returned by the method, any of
            'Fingers', 'Wrist', 'Elbow', 'Rest' (default: all).
            The string 'all' or None also selects every class.
        all_trials: bool
            currently unused; kept for interface compatibility.

        Returns
        -------
        None
            if the session folder holds no .edf file, or if an extracted
            trial does not have the expected number of samples.
        data: ndarray, shape (n_trials, 32, 1125)
        classes: ndarray of int, shape (n_trials,)
            zero-based labels (0 Fingers, 1 Wrist, 2 Elbow, 3 Rest)
        onset, duration, description: ndarray, shape (5, 51)
            per-run annotation values of the selected trials.
    """
    # get_data() passes 'all' by default. As a plain string, every
    # "'Fingers' in classes_labels" test was False and no trial was ever
    # selected, so map 'all' (and None) to the full class set.
    if classes_labels is None or (isinstance(classes_labels, str)
                                  and classes_labels == 'all'):
        classes_labels = ('Fingers', 'Wrist', 'Elbow', 'Rest')

    # Trial number -> class name used for the class selection below.
    class_names = {1: 'Fingers', 2: 'Wrist', 3: 'Elbow', 4: 'Rest'}

    # Get all subject folders.
    subjectFiles = glob.glob(data_path + 'S_*/')

    # Get all subjects numbers sorted without duplicates
    # (the three characters before the trailing separator).
    subjectNo = list(dict.fromkeys(sorted([x[len(x)-4:len(x)-1] for x in subjectFiles])))

    session = 1 if training else 2

    num_runs = 5
    sfreq = 250
    mi_duration = 4.5

    data = np.zeros([num_runs*51, 32, int(mi_duration*sfreq)])
    classes = np.zeros(num_runs * 51)
    valid_trails = 0

    onset = np.zeros([num_runs, 51])
    duration = np.zeros([num_runs, 51])
    description = np.zeros([num_runs, 51])

    subject_id = subjectNo[subject].zfill(3)
    session_dir = data_path + 'S_' + subject_id + '/S' + str(session) + '/'

    # Nothing to load if the session folder has no recordings at all.
    CheckFiles = glob.glob(session_dir + '*.edf')
    if not CheckFiles:
        return

    # Loop over all runs of the session.
    for runNo in range(num_runs):
        valid_trails_in_run = 0
        # Get .edf and .json file for following subject and run.
        run_prefix = session_dir + 'S_' + subject_id + '_' + str(session) + str(runNo+1)
        EDFfile = glob.glob(run_prefix + '*.edf')
        JSONfile = glob.glob(run_prefix + '*.json')

        # Skip runs without a recording.
        if not EDFfile:
            continue

        # We use mne.read_raw_edf to read in the .edf EEG files
        raw = read_raw_edf(str(EDFfile[0]), preload=True, verbose=False)

        # Read the markers of the current run; the context manager makes
        # sure the file handle is closed (it used to be leaked).
        with open(JSONfile[0]) as f:
            JSON = json.load(f)

        # Number of keystroke markers, capped at 51 to avoid extra markers
        # added by accident.
        keyStrokes = min(len(JSON['Markers']), 51)

        # Recording start time, parsed from the timestamp embedded in the
        # .edf file name and interpreted in the local time zone.
        date_string = EDFfile[0][-21:-4]
        datetime_format = "%d.%m.%y_%H.%M.%S"
        startRecordTime = datetime.strptime(date_string, datetime_format).astimezone()

        currentTrialNo = 0 # 1 = fingers, 2 = Wrist, 3 = Elbow, 4 = rest
        if runNo == 4:
            currentTrialNo = 4

        ch_names = raw.info['ch_names'][4:36]

        # filter the data
        raw.filter(4., 50., fir_design='firwin')

        raw = raw.copy().pick_channels(ch_names = ch_names)
        raw = raw.copy().resample(sfreq = sfreq)
        fs = raw.info['sfreq']

        for trail in range(keyStrokes):

            # class for current trial
            if runNo == 4:                 # In Run 5 all trials are 'Rest'
                currentTrialNo = 4
            elif currentTrialNo == 3:      # After 'Elbow' restart at 1 'Fingers'
                currentTrialNo = 1
            else:                          # Runs 1-4 cycle 'Fingers', 'Wrist', 'Elbow', ...
                currentTrialNo = currentTrialNo + 1

            trailDuration = 8

            trailTime = parse(JSON['Markers'][trail]['startDatetime'])
            trailStart = trailTime - startRecordTime
            # NOTE(review): timedelta.seconds drops the fractional part
            # (and any whole days) of the offset - confirm this is intended.
            trailStart = trailStart.seconds
            start = trailStart + (6 - mi_duration)
            stop = trailStart + 6

            if trail < keyStrokes-1:
                # Trial duration = distance to the next marker.
                trailDuration = parse(JSON['Markers'][trail+1]['startDatetime']) - parse(JSON['Markers'][trail]['startDatetime'])
                trailDuration = trailDuration.seconds + (trailDuration.microseconds/1000000)
                if (trailDuration < 7.5) or (trailDuration > 8.5):
                    print('In Session: {} - Run: {}, Trail no: {} is skipped due to short/long duration of: {:.2f}'.format(session, (runNo+1), (trail+1), trailDuration))
                    # A gap of about two trials: advance the class counter
                    # once more before skipping.
                    if 14 < trailDuration < 18:
                        if currentTrialNo == 3:
                            currentTrialNo = 1
                        else:
                            currentTrialNo = currentTrialNo + 1
                    continue

            elif trail == keyStrokes-1:
                # Last trial: duration = samples actually available within
                # the next 8 seconds of the recording.
                trailDuration = raw[0, int(trailStart*int(fs)):int((trailStart+8)*int(fs))][0].shape[1]/fs
                if trailDuration < 7.8:
                    print('In Session: {} - Run: {}, Trail no: {} is skipped due to short/long duration of: {:.2f}'.format(session, (runNo+1), (trail+1), trailDuration))
                    continue

            MITrail = raw[:32, int(start*int(fs)):int(stop*int(fs))][0]
            if MITrail.shape[1] != data.shape[2]:
                print('Error in Session: {} - Run: {}, Trail no: {} due to the lost of data'.format(session, (runNo+1), (trail+1)))
                return

            # select some specific classes
            if class_names[currentTrialNo] in classes_labels:
                data[valid_trails] = MITrail
                classes[valid_trails] = currentTrialNo

                # For Annotations
                onset[runNo, valid_trails_in_run] = start
                duration[runNo, valid_trails_in_run] = trailDuration - (6 - mi_duration)
                description[runNo, valid_trails_in_run] = currentTrialNo
                valid_trails += 1
                valid_trails_in_run += 1

    # Drop the unused pre-allocated rows and make the labels zero-based.
    data = data[0:valid_trails, :, :]
    classes = classes[0:valid_trails]
    classes = (classes-1).astype(int)

    return data, classes, onset, duration, description
305
306
307
#%%
308
def standardize_data(X_train, X_test, channels):
    """Standardize each channel with statistics fitted on the training set.

    X_train & X_test: [Trials, MI-tasks, Channels, Time points]

    For every channel index below `channels`, a StandardScaler is fitted on
    the training slice [:, 0, channel, :] and then applied to both the
    training and the testing slice. Both arrays are modified in place and
    returned.
    """
    for ch in range(channels):
        train_slice = X_train[:, 0, ch, :]
        # Fit on training data only, then reuse the same scaler for the
        # test data.
        fitted = StandardScaler().fit(train_slice)
        X_train[:, 0, ch, :] = fitted.transform(train_slice)
        X_test[:, 0, ch, :] = fitted.transform(X_test[:, 0, ch, :])

    return X_train, X_test
317
318
319
#%%
320
def get_data(path, subject, dataset='BCI2a', classes_labels='all', LOSO=False, isStandard=True, isShuffle=True):
    """Load a dataset and prepare it for training and evaluation.

    Two splitting schemes are supported:

    * LOSO=True: 'Leave One Subject Out' (subject-independent); the split
      is delegated to load_data_LOSO (`classes_labels` is not used there).
    * LOSO=False: subject-specific (subject-dependent) approach, using the
      same training and testing data as the original competition, i.e.,
      for BCI Competition IV-2a, 288 x 9 trials in session 1 for training,
      and 288 x 9 trials in session 2 for testing.

        Parameters
        ----------
        path: string
            dataset root path (must end with a path separator)
        subject: int
            ZERO-based subject index (for 'BCI2a' the files of subject
            number subject + 1 under 's<subject+1>/' are read)
        dataset: string
            'BCI2a' or 'CS2R'
        classes_labels: 'all' or sequence of strings
            classes to load for the 'CS2R' dataset; 'all' selects
            'Fingers', 'Wrist', 'Elbow' and 'Rest'
        LOSO: bool
            if True, use the leave-one-subject-out split
        isStandard: bool
            if True, standardize every channel (fitted on training data)
        isShuffle: bool
            if True, shuffle training and testing trials (random_state=42)

        Returns
        -------
        X_train, y_train, y_train_onehot, X_test, y_test, y_test_onehot
            X_* have shape (trials, 1, channels, time points); y_* are the
            integer labels and y_*_onehot their one-hot encoding.

        Raises
        ------
        ValueError
            if `dataset` is not supported (subject-specific split).
    """
    # Load and split the dataset into training and testing
    if LOSO:
        X_train, y_train, X_test, y_test = load_data_LOSO(path, subject, dataset)
    elif dataset == 'BCI2a':
        path = path + 's{:}/'.format(subject+1)
        X_train, y_train = load_BCI2a_data(path, subject+1, True)
        X_test, y_test = load_BCI2a_data(path, subject+1, False)
    elif dataset == 'CS2R':
        # The loader tests "'Fingers' in classes_labels" etc.; with the
        # plain string 'all' every test is False and no trial is returned,
        # so expand 'all' into the explicit class list.
        if isinstance(classes_labels, str) and classes_labels == 'all':
            classes_labels = ['Fingers', 'Wrist', 'Elbow', 'Rest']
        X_train, y_train, _, _, _ = load_CS2R_data_v2(path, subject, True, classes_labels)
        X_test, y_test, _, _, _ = load_CS2R_data_v2(path, subject, False, classes_labels)
    # elif (dataset == 'HGD'):
    #     X_train, y_train = load_HGD_data(path, subject+1, True)
    #     X_test, y_test = load_HGD_data(path, subject+1, False)
    else:
        raise ValueError("'{}' dataset is not supported yet!".format(dataset))

    # shuffle the data
    if isShuffle:
        X_train, y_train = shuffle(X_train, y_train, random_state=42)
        X_test, y_test = shuffle(X_test, y_test, random_state=42)

    # Prepare training data: add a singleton axis after the trial axis.
    N_tr, N_ch, T = X_train.shape
    X_train = X_train.reshape(N_tr, 1, N_ch, T)
    y_train_onehot = to_categorical(y_train)
    # Prepare testing data
    N_tr, N_ch, T = X_test.shape
    X_test = X_test.reshape(N_tr, 1, N_ch, T)
    y_test_onehot = to_categorical(y_test)

    # Standardize the data
    if isStandard:
        X_train, X_test = standardize_data(X_train, X_test, N_ch)

    return X_train, y_train, y_train_onehot, X_test, y_test, y_test_onehot
366