Diff of /train_within.py [000000] .. [5d1c0a]

Switch to side-by-side view

--- a
+++ b/train_within.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# coding: utf-8
+'''Subject-specific classification with KU Data,
+using Deep ConvNet model from [1].
+
+References
+----------
+.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
+   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
+   Deep learning with convolutional neural networks for EEG decoding and
+   visualization.
+   Human Brain Mapping , Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
+'''
+
+import argparse
+import json
+import logging
+import sys
+from os.path import join as pjoin
+
+import h5py
+import torch
+import torch.nn.functional as F
+from braindecode.models.deep4 import Deep4Net
+from braindecode.torch_ext.optimizers import AdamW
+from braindecode.torch_ext.util import set_random_seeds
+
+logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
+                    level=logging.INFO, stream=sys.stdout)
+parser = argparse.ArgumentParser(
+    description='Subject-specific classification with KU Data')
+parser.add_argument('datapath', type=str, help='Path to the h5 data file')
+parser.add_argument('outpath', type=str, help='Path to the result folder')
+parser.add_argument('-gpu', type=int,
+                    help='The gpu device index to use', default=0)
+parser.add_argument('-start', type=int,
+                    help='Start of the subject index', default=1)
+parser.add_argument(
+    '-end', type=int, help='End of the subject index (not inclusive)', default=55)
+parser.add_argument('-subj', type=int, nargs='+',
+                    help='Explicitly set the subject number. This will override the start and end argument')
+args = parser.parse_args()
+
+datapath = args.datapath
+outpath = args.outpath
+start = args.start
+end = args.end
+assert(start < end)
+subjs = args.subj if args.subj else range(start, end)
+dfile = h5py.File(datapath, 'r')
+torch.cuda.set_device(args.gpu)
+set_random_seeds(seed=20200205, cuda=True)
+
+
+def get_data(subj):
+    dpath = '/s' + str(subj)
+    X = dfile[pjoin(dpath, 'X')]
+    Y = dfile[pjoin(dpath, 'Y')]
+    return X[:], Y[:]
+
+
+for subj in subjs:
+    # Get data for within-subject classification
+    X, Y = get_data(subj)
+    X_train, Y_train = X[:200], Y[:200]
+    X_val, Y_val = X[200:300], Y[200:300]
+    X_test, Y_test = X[300:], Y[300:]
+
+    suffix = 's' + str(subj)
+    n_classes = 2
+    in_chans = X.shape[1]
+
+    # final_conv_length = auto ensures we only get a single output in the time dimension
+    model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
+                     input_time_length=X.shape[2],
+                     final_conv_length='auto').cuda()
+
+    # these are good values for the deep model
+    optimizer = AdamW(model.parameters(), lr=1 * 0.01, weight_decay=0.5*0.001)
+    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, )
+
+    model.fit(X_train, Y_train, epochs=200, batch_size=16, scheduler='cosine',
+              validation_data=(X_val, Y_val), remember_best_column='valid_loss')
+
+    test_loss = model.evaluate(X_test, Y_test)
+    model.epochs_df.to_csv(pjoin(outpath, 'epochs_' + suffix + '.csv'))
+    with open(pjoin(outpath, 'test_subj_' + str(subj) + '.json'), 'w') as f:
+        json.dump(test_loss, f)
+
+dfile.close()