"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Author:  Hamdi Altaheri
"""

#%%
# We need the following to load and preprocess the High Gamma Dataset
import logging
import os
from collections import OrderedDict

import numpy as np

from braindecode.datasets.bbci import BBCIDataset
from braindecode.datautil.trial_segment import \
    create_signal_target_from_raw_mne
from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
from braindecode.datautil.signalproc import exponential_running_standardize
from braindecode.datautil.signalproc import highpass_cnt

#%%
def load_HGD_data(data_path, subject, training, low_cut_hz=0, debug=False):
    """Load training/testing data of one High Gamma Dataset (HGD) subject.

    Please note that HGD is for "executed movements" NOT "motor imagery".

    This code is adapted from https://github.com/robintibor/high-gamma-dataset
    You can download the HGD using the following link:
        https://gin.g-node.org/robintibor/high-gamma-dataset/src/master/data
    The Braindecode library is required to load and process the HGD dataset.

    Parameters
    ----------
    data_path : str
        Dataset path. It must contain 'train' and 'test' sub-folders with one
        '<subject>.mat' file per subject (trailing separator is optional).
    subject : int
        Number of subject in [1, .., 14].
    training : bool
        If True, load training data;
        if False, load testing data.
    low_cut_hz : float, optional
        Low cut-off frequency (Hz) for the high-pass filter (default 0).
    debug : bool, optional
        If True, load only three sensors ('C3', 'C4', 'C2') for quick runs;
        if False, load all sensors (default).

    Returns
    -------
    X : numpy.ndarray
        Trials kept after artifact cleaning,
        shaped (n_clean_trials, n_channels, n_samples).
    y : numpy.ndarray
        Class labels of the kept trials.
    """
    log = logging.getLogger(__name__)
    log.setLevel('DEBUG')

    # os.path.join tolerates data_path with or without a trailing separator;
    # plain string concatenation silently built a wrong path without one.
    split = 'train' if training else 'test'
    filename = os.path.join(data_path, split, '{}.mat'.format(subject))

    load_sensor_names = None
    if debug:
        load_sensor_names = ['C3', 'C4', 'C2']
    # We load all sensors to always get the same cleaning results independent
    # of sensor selection. There is an inbuilt heuristic that tries to use only
    # EEG channels and that definitely works for datasets in our paper.
    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)

    log.info("Loading data...")
    cnt = loader.load()

    # Cleaning: first find all trials that have absolute microvolt values
    # larger than +- 800 inside them and remember them for removal later.
    log.info("Cutting trials...")

    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2]),
                              ('Rest', [3]), ('Feet', [4])])
    clean_ival = [0, 4000]  # ms relative to trial onset

    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
                                                         clean_ival)

    # Boolean mask over trials: True where the whole trial stays below 800 uV.
    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800

    log.info("Clean trials: {:3d}  of {:3d} ({:5.1f}%)".format(
        np.sum(clean_trial_mask),
        len(set_for_cleaning.X),
        np.mean(clean_trial_mask) * 100))

    # Now pick only sensors with C in their name,
    # as they cover motor cortex.
    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
                 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
                 'C6',
                 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h',
                 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
                 'CPP5h',
                 'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
                 'CCP1h',
                 'CCP2h', 'CPP1h', 'CPP2h']
    if debug:
        C_sensors = load_sensor_names
    cnt = cnt.pick_channels(C_sensors)

    # Further preprocessing as described in the paper.
    log.info("Resampling...")
    cnt = resample_cnt(cnt, 250.0)
    log.info("Highpassing...")
    cnt = mne_apply(
        lambda a: highpass_cnt(
            a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
        cnt)
    log.info("Standardizing...")
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
                                                  init_block_size=1000,
                                                  eps=1e-4).T,
        cnt)

    # Trial interval: start at -500 ms already, since that improved decoding
    # for the networks.
    ival = [-500, 4000]

    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
    # Keep only the trials that passed the +-800 uV cleaning step.
    dataset.X = dataset.X[clean_trial_mask]
    dataset.y = dataset.y[clean_trial_mask]
    return dataset.X, dataset.y