Diff of /train_adapt.py [000000] .. [5d1c0a]

--- a
+++ b/train_adapt.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python
+# coding: utf-8
+'''Subject-adaptive classification with KU Data,
+using the Deep ConvNet model from [1].
+
+References
+----------
+.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
+   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
+   Deep learning with convolutional neural networks for EEG decoding and
+   visualization.
+   Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
+'''
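+# The HDF5 data file is expected to hold one group per subject, /s{subj},
+# with datasets X (the trials) and Y (the class labels); see get_data below.
+# Illustrative invocation (file names are placeholders, not from the source):
+#   python train_adapt.py ku_data.h5 ./base_models ./results -scheme 4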
+import argparse
+import json
+import logging
+import sys
+from os.path import join as pjoin
+
+import h5py
+import torch
+import torch.nn.functional as F
+from braindecode.models.deep4 import Deep4Net
+from braindecode.torch_ext.optimizers import AdamW
+from braindecode.torch_ext.util import set_random_seeds
+from torch import nn
+
+logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
+                    level=logging.INFO, stream=sys.stdout)
+
+parser = argparse.ArgumentParser(
+    description='Subject-adaptive classification with KU Data')
+parser.add_argument('datapath', type=str, help='Path to the h5 data file')
+parser.add_argument('modelpath', type=str,
+                    help='Path to the base model folder')
+parser.add_argument('outpath', type=str, help='Path to the result folder')
+parser.add_argument('-scheme', type=int, help='Adaptation scheme (1-5)', default=4)
+parser.add_argument(
+    '-trfrate', type=int, help='The percentage of data for adaptation', default=100)
+parser.add_argument('-lr', type=float, help='Learning rate', default=0.0005)
+parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)
+
+args = parser.parse_args()
+datapath = args.datapath
+outpath = args.outpath
+modelpath = args.modelpath
+scheme = args.scheme
+rate = args.trfrate
+lr = args.lr
+dfile = h5py.File(datapath, 'r')
+torch.cuda.set_device(args.gpu)
+set_random_seeds(seed=20200205, cuda=True)
+BATCH_SIZE = 16
+TRAIN_EPOCH = 200
+
+# All 54 subjects of the KU dataset, in a fixed randomly shuffled order.
+subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
+         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
+         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]
+
+
+# Get data from a single subject.
+def get_data(subj):
+    dpath = '/s' + str(subj)
+    X = dfile[pjoin(dpath, 'X')]
+    Y = dfile[pjoin(dpath, 'Y')]
+    return X[:], Y[:]
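+# X is assumed to have shape (n_trials, n_channels, n_samples): the model
+# below reads in_chans from X.shape[1] and input_time_length from X.shape[2].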
+
+
+X, Y = get_data(subjs[0])
+n_classes = 2
+in_chans = X.shape[1]
+# final_conv_length='auto' ensures we only get a single output in the time dimension.
+model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
+                 input_time_length=X.shape[2],
+                 final_conv_length='auto').cuda()
+
+# Deprecated: an earlier approach that re-initialized entire conv/pool blocks
+# (see the commented-out calls in reset_model); kept for reference.
+
+
+def reset_conv_pool_block(network, block_nr):
+    # Replace the block's conv and batch-norm layers with freshly initialized
+    # ones; the pooling layer has no learnable parameters, so it is left as-is.
+    suffix = "_{:d}".format(block_nr)
+    conv = getattr(network, 'conv' + suffix)
+    kernel_size = conv.kernel_size
+    n_filters_before = conv.in_channels
+    n_filters = conv.out_channels
+    setattr(network, 'conv' + suffix,
+            nn.Conv2d(
+                n_filters_before,
+                n_filters,
+                kernel_size,
+                stride=(1, 1),
+                bias=False,
+            ))
+    setattr(network, 'bnorm' + suffix,
+            nn.BatchNorm2d(
+                n_filters,
+                momentum=0.1,
+                affine=True,
+                eps=1e-5,
+            ))
+    # Initialize the layers.
+    conv = getattr(network, 'conv' + suffix)
+    bnorm = getattr(network, 'bnorm' + suffix)
+    nn.init.xavier_uniform_(conv.weight, gain=1)
+    nn.init.constant_(bnorm.weight, 1)
+    nn.init.constant_(bnorm.bias, 0)
+
+
+def reset_model(checkpoint):
+    # Restore the pretrained base-model weights, then freeze/unfreeze layers
+    # according to the chosen adaptation scheme.
+    model.network.load_state_dict(checkpoint['model_state_dict'])
+
+    # # Resets the last conv block
+    # reset_conv_pool_block(model.network, block_nr=4)
+    # reset_conv_pool_block(model.network, block_nr=3)
+    # reset_conv_pool_block(model.network, block_nr=2)
+    # # Resets the fully-connected layer.
+    # # Parameters of newly constructed modules have requires_grad=True by default.
+    # n_final_conv_length = model.network.conv_classifier.kernel_size[0]
+    # n_prev_filter = model.network.conv_classifier.in_channels
+    # n_classes = model.network.conv_classifier.out_channels
+    # model.network.conv_classifier = nn.Conv2d(
+    #     n_prev_filter, n_classes, (n_final_conv_length, 1), bias=True)
+    # nn.init.xavier_uniform_(model.network.conv_classifier.weight, gain=1)
+    # nn.init.constant_(model.network.conv_classifier.bias, 0)
+
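+    # Adaptation schemes unfreeze layers cumulatively, from the classifier
+    # backwards: 1 = FC only; 2 = + conv4; 3 = + conv3; 4 = + conv2 (default);
+    # 5 = fine-tune the whole network (nothing frozen).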
+    if scheme != 5:
+        # Freeze all layers.
+        for param in model.network.parameters():
+            param.requires_grad = False
+
+        if scheme in {1, 2, 3, 4}:
+            # Unfreeze the FC layer.
+            for param in model.network.conv_classifier.parameters():
+                param.requires_grad = True
+
+        if scheme in {2, 3, 4}:
+            # Unfreeze the conv4 layer.
+            for param in model.network.conv_4.parameters():
+                param.requires_grad = True
+            for param in model.network.bnorm_4.parameters():
+                param.requires_grad = True
+
+        if scheme in {3, 4}:
+            # Unfreeze the conv3 layer.
+            for param in model.network.conv_3.parameters():
+                param.requires_grad = True
+            for param in model.network.bnorm_3.parameters():
+                param.requires_grad = True
+
+        if scheme == 4:
+            # Unfreeze the conv2 layer.
+            for param in model.network.conv_2.parameters():
+                param.requires_grad = True
+            for param in model.network.bnorm_2.parameters():
+                param.requires_grad = True
+
+    # Only optimize parameters that require gradients.
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.network.parameters()),
+                      lr=lr, weight_decay=0.5*0.001)
+    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=20200205)
+
+# Use only session 1 data (the first 200 trials per subject) for adaptation.
+cutoff = int(rate * 200 / 100)
+assert cutoff <= 200
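+# trfrate is a percentage of those 200 trials: e.g. -trfrate 50 adapts on the
+# first 100 trials, while the default of 100 uses all of them.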
+
+for fold, subj in enumerate(subjs):
+    suffix = '_s' + str(subj) + '_f' + str(fold)
+    # Load this fold's pretrained base model onto the selected GPU.
+    checkpoint = torch.load(pjoin(modelpath, 'model_f' + str(fold) + '.pt'),
+                            map_location='cuda:' + str(args.gpu))
+    reset_model(checkpoint)
+
+    X, Y = get_data(subj)
+    # First `cutoff` trials for adaptation, trials 200-299 for validation,
+    # and the rest for the final test.
+    X_train, Y_train = X[:cutoff], Y[:cutoff]
+    X_val, Y_val = X[200:300], Y[200:300]
+    X_test, Y_test = X[300:], Y[300:]
+    model.fit(X_train, Y_train, epochs=TRAIN_EPOCH,
+              batch_size=BATCH_SIZE, scheduler='cosine',
+              validation_data=(X_val, Y_val), remember_best_column='valid_loss')
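+    # Note: scheduler='cosine' anneals the learning rate over the run, and
+    # remember_best_column='valid_loss' tracks the best-validation epoch
+    # (braindecode 0.4 models API).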
+    model.epochs_df.to_csv(pjoin(outpath, 'epochs' + suffix + '.csv'))
+    test_loss = model.evaluate(X_test, Y_test)
+    # Persist the test metrics for this subject/fold.
+    with open(pjoin(outpath, 'test' + suffix + '.json'), 'w') as f:
+        json.dump(test_loss, f)
+
+dfile.close()