--- a/train_adapt.py
+++ b/train_adapt.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# coding: utf-8
'''Subject-adaptive classification with KU Data,
using the Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''
import argparse
import json
import logging
import sys
from os.path import join as pjoin

import h5py
import torch
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from torch import nn

logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

parser = argparse.ArgumentParser(
    description='Subject-adaptive classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('modelpath', type=str,
                    help='Path to the base model folder')
parser.add_argument('outpath', type=str, help='Path to the result folder')
parser.add_argument('-scheme', type=int, help='Adaptation scheme', default=4)
parser.add_argument(
    '-trfrate', type=int, help='The percentage of data used for adaptation', default=100)
parser.add_argument('-lr', type=float, help='Learning rate', default=0.0005)
parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)

args = parser.parse_args()
datapath = args.datapath
outpath = args.outpath
modelpath = args.modelpath
scheme = args.scheme
rate = args.trfrate
lr = args.lr
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200

# Subjects in randomly shuffled order.
subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]


# Get the data of a single subject.
def get_data(subj):
    dpath = '/s' + str(subj)
    X = dfile[pjoin(dpath, 'X')]
    Y = dfile[pjoin(dpath, 'Y')]
    return X[:], Y[:]


X, Y = get_data(subjs[0])
n_classes = 2
in_chans = X.shape[1]
# final_conv_length='auto' ensures we only get a single output in the time dimension.
model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                 input_time_length=X.shape[2],
                 final_conv_length='auto').cuda()


# Deprecated: replaces a conv-pool block with freshly initialized layers.
# Kept for reference only; the adaptation schemes below freeze and unfreeze
# the pretrained layers instead of resetting them.
def reset_conv_pool_block(network, block_nr):
    suffix = "_{:d}".format(block_nr)
    conv = getattr(network, 'conv' + suffix)
    kernel_size = conv.kernel_size
    n_filters_before = conv.in_channels
    n_filters = conv.out_channels
    setattr(network, 'conv' + suffix,
            nn.Conv2d(
                n_filters_before,
                n_filters,
                kernel_size,
                stride=(1, 1),
                bias=False,
            ))
    setattr(network, 'bnorm' + suffix,
            nn.BatchNorm2d(
                n_filters,
                momentum=0.1,
                affine=True,
                eps=1e-5,
            ))
    # Initialize the new layers.
    conv = getattr(network, 'conv' + suffix)
    bnorm = getattr(network, 'bnorm' + suffix)
    nn.init.xavier_uniform_(conv.weight, gain=1)
    nn.init.constant_(bnorm.weight, 1)
    nn.init.constant_(bnorm.bias, 0)
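
# The adaptation schemes in reset_model() below work by freezing the whole
# pretrained network and then unfreezing progressively more layers. As a
# sanity check that a scheme unfroze the intended layers, a small helper
# like the following can be used (a minimal sketch, not part of the
# original script):
def count_trainable_parameters(network):
    # Sum the element counts of all parameters that still require gradients.
    return sum(p.numel() for p in network.parameters() if p.requires_grad)
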
def reset_model(checkpoint):
    # Schemes (see the if-blocks below): scheme 1 unfreezes only the
    # classifier; schemes 2-4 additionally unfreeze conv4, conv3 and conv2;
    # scheme 5 fine-tunes the whole network.
    # Load the pretrained state dict of the base model.
    model.network.load_state_dict(checkpoint['model_state_dict'])

    # # Resets the last conv block
    # reset_conv_pool_block(model.network, block_nr=4)
    # reset_conv_pool_block(model.network, block_nr=3)
    # reset_conv_pool_block(model.network, block_nr=2)
    # # Resets the fully-connected layer.
    # # Parameters of newly constructed modules have requires_grad=True by default.
    # n_final_conv_length = model.network.conv_classifier.kernel_size[0]
    # n_prev_filter = model.network.conv_classifier.in_channels
    # n_classes = model.network.conv_classifier.out_channels
    # model.network.conv_classifier = nn.Conv2d(
    #     n_prev_filter, n_classes, (n_final_conv_length, 1), bias=True)
    # nn.init.xavier_uniform_(model.network.conv_classifier.weight, gain=1)
    # nn.init.constant_(model.network.conv_classifier.bias, 0)

    if scheme != 5:
        # Freeze all layers.
        for param in model.network.parameters():
            param.requires_grad = False

    if scheme in {1, 2, 3, 4}:
        # Unfreeze the FC layer.
        for param in model.network.conv_classifier.parameters():
            param.requires_grad = True

    if scheme in {2, 3, 4}:
        # Unfreeze the conv4 layer.
        for param in model.network.conv_4.parameters():
            param.requires_grad = True
        for param in model.network.bnorm_4.parameters():
            param.requires_grad = True

    if scheme in {3, 4}:
        # Unfreeze the conv3 layer.
        for param in model.network.conv_3.parameters():
            param.requires_grad = True
        for param in model.network.bnorm_3.parameters():
            param.requires_grad = True

    if scheme == 4:
        # Unfreeze the conv2 layer.
        for param in model.network.conv_2.parameters():
            param.requires_grad = True
        for param in model.network.bnorm_2.parameters():
            param.requires_grad = True

    # Only optimize parameters that require gradients.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.network.parameters()),
                      lr=lr, weight_decay=0.5 * 0.001)
    model.compile(loss=F.nll_loss, optimizer=optimizer,
                  iterator_seed=20200205)


# Use only session 1 data (the first 200 trials) for training.
cutoff = int(rate * 200 / 100)
assert cutoff <= 200

for fold, subj in enumerate(subjs):
    suffix = '_s' + str(subj) + '_f' + str(fold)
    checkpoint = torch.load(pjoin(modelpath, 'model_f' + str(fold) + '.pt'),
                            map_location='cuda:' + str(args.gpu))
    reset_model(checkpoint)

    X, Y = get_data(subj)
    # Adapt on the first `cutoff` trials; validate on trials 200-299 and
    # test on the remaining trials.
    X_train, Y_train = X[:cutoff], Y[:cutoff]
    X_val, Y_val = X[200:300], Y[200:300]
    X_test, Y_test = X[300:], Y[300:]
    model.fit(X_train, Y_train, epochs=TRAIN_EPOCH,
              batch_size=BATCH_SIZE, scheduler='cosine',
              validation_data=(X_val, Y_val), remember_best_column='valid_loss')
    model.epochs_df.to_csv(pjoin(outpath, 'epochs' + suffix + '.csv'))
    test_loss = model.evaluate(X_test, Y_test)
    with open(pjoin(outpath, 'test' + suffix + '.json'), 'w') as f:
        json.dump(test_loss, f)

dfile.close()
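
# For reference, get_data() above implies the following HDF5 layout: one
# group per subject ('/s1' ... '/s54'), each holding an 'X' dataset of shape
# (trials, channels, samples) and a 'Y' dataset of labels. A minimal sketch
# of building such a file (the variable subject_arrays is hypothetical;
# only the dataset names come from this script):
#
#   with h5py.File('ku_data.h5', 'w') as f:
#       for subj, (X_subj, Y_subj) in enumerate(subject_arrays, start=1):
#           f.create_dataset('s{}/X'.format(subj), data=X_subj)
#           f.create_dataset('s{}/Y'.format(subj), data=Y_subj)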
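
# Example invocation (the paths are placeholders): adapt the pretrained base
# models with scheme 4, using all of each subject's session 1 data:
#
#   python train_adapt.py /path/to/data.h5 /path/to/base_models /path/to/results \
#       -scheme 4 -trfrate 100 -lr 0.0005 -gpu 0
#
# The base model folder is expected to contain one checkpoint per fold,
# named 'model_f<fold>.pt', as loaded in the loop above.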