a b/train_within.py
1
#!/usr/bin/env python
# coding: utf-8
'''Subject-specific (within-subject) classification of the KU data,
using the Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping, Aug. 2017. Online: https://doi.org/10.1002/hbm.23730
'''
14
15
import argparse
import json
import logging
import os
import sys
from os.path import join as pjoin

import h5py
import torch
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
27
28
# Script setup: logging, command-line parsing, and the module-level state
# (dfile, subjs, outpath, selected GPU, RNG seed) used by the training loop.

# Configure the root logger to print INFO-level records to stdout.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)
parser = argparse.ArgumentParser(
    description='Subject-specific classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('outpath', type=str, help='Path to the result folder')
parser.add_argument('-gpu', type=int,
                    help='The gpu device index to use', default=0)
parser.add_argument('-start', type=int,
                    help='Start of the subject index', default=1)
parser.add_argument(
    '-end', type=int, help='End of the subject index (not inclusive)', default=55)
parser.add_argument('-subj', type=int, nargs='+',
                    help='Explicitly set the subject number. This will override the start and end argument')
args = parser.parse_args()

datapath = args.datapath
outpath = args.outpath
start = args.start
end = args.end
# Validate with parser.error instead of `assert`: asserts are stripped under
# `python -O`. The range is only used when -subj is not given, so it is only
# checked in that case.
if not args.subj and start >= end:
    parser.error('-start ({}) must be smaller than -end ({})'.format(start, end))
subjs = args.subj if args.subj else range(start, end)
# Create the result folder up front. Otherwise a missing folder would only
# surface when the first subject's results are written, after its training.
os.makedirs(outpath, exist_ok=True)
dfile = h5py.File(datapath, 'r')
# Select the default CUDA device used by the `.cuda()` calls later on.
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
53
54
55
def get_data(subj, h5file=None):
    '''Load all trials of one subject from the HDF5 data file.

    Parameters
    ----------
    subj : int
        Subject number; data is read from the group ``/s<subj>``.
    h5file : h5py.File or mapping, optional
        Object indexed by HDF5 path strings. Defaults to the module-level
        ``dfile`` opened from ``datapath``.

    Returns
    -------
    X, Y
        The full contents of ``/s<subj>/X`` and ``/s<subj>/Y``, read with
        ``[:]`` (for an h5py dataset this loads it into memory).
    '''
    if h5file is None:
        h5file = dfile
    # HDF5 object paths always use '/', whatever the OS. os.path.join would
    # insert '\\' on Windows and the lookup would fail.
    group = '/s' + str(subj)
    X = h5file[group + '/X']
    Y = h5file[group + '/Y']
    return X[:], Y[:]
60
61
62
# Train, evaluate and save results for every selected subject. The HDF5 file
# is closed in `finally` so the handle is released even if a subject fails.
try:
    for subj in subjs:
        # Reseed for every subject (same seed as the module-level call) so each
        # subject's model starts from the same RNG state regardless of which
        # subjects were trained before it in this run. Without this, rerunning
        # one subject via -subj would not reproduce its result from a range run
        # (up to any GPU non-determinism).
        set_random_seeds(seed=20200205, cuda=True)

        # Within-subject split by trial order: first 200 trials train, next
        # 100 validate, all remaining trials test.
        # NOTE(review): assumes trials are stored in recording order
        # (presumably session 1 then session 2) — confirm against the h5 file.
        X, Y = get_data(subj)
        X_train, Y_train = X[:200], Y[:200]
        X_val, Y_val = X[200:300], Y[200:300]
        X_test, Y_test = X[300:], Y[300:]

        suffix = 's' + str(subj)
        n_classes = 2
        # X is indexed here as (trials, channels, time).
        in_chans = X.shape[1]

        # final_conv_length='auto' ensures we only get a single output in the
        # time dimension
        model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                         input_time_length=X.shape[2],
                         final_conv_length='auto').cuda()

        # these are good values for the deep model
        optimizer = AdamW(model.parameters(), lr=1 * 0.01,
                          weight_decay=0.5 * 0.001)
        model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)

        # remember_best_column: presumably restores the parameters from the
        # epoch with the lowest 'valid_loss' — see braindecode BaseModel.fit.
        model.fit(X_train, Y_train, epochs=200, batch_size=16,
                  scheduler='cosine', validation_data=(X_val, Y_val),
                  remember_best_column='valid_loss')

        # Presumably a dict of test metrics (not only the loss); it is written
        # out as JSON below.
        test_loss = model.evaluate(X_test, Y_test)
        model.epochs_df.to_csv(pjoin(outpath, 'epochs_' + suffix + '.csv'))
        with open(pjoin(outpath, 'test_subj_' + str(subj) + '.json'), 'w') as f:
            json.dump(test_loss, f)
finally:
    dfile.close()