#!/usr/bin/env python
# coding: utf-8
'''Subject-adaptive classification with KU Data,
using the Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''
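# Example invocation (paths and flag values below are illustrative
# placeholders, not files shipped with this script):
#   python train_adapt.py data.h5 base_models/ results/ -scheme 4 -trfrate 100 -gpu 0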
import argparse
import json
import logging
import sys
from os.path import join as pjoin

import h5py
import torch
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from torch import nn

logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

parser = argparse.ArgumentParser(
    description='Subject-adaptive classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('modelpath', type=str,
                    help='Path to the base model folder')
parser.add_argument('outpath', type=str, help='Path to the result folder')
parser.add_argument('-scheme', type=int, help='Adaptation scheme (1-5)', default=4)
parser.add_argument(
    '-trfrate', type=int, default=100,
    help='The percentage of the session-1 trials used for adaptation')
parser.add_argument('-lr', type=float, help='Learning rate', default=0.0005)
parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)
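
# Adaptation schemes (as derived from the freezing logic in reset_model below):
#   1: fine-tune the classifier layer only
#   2: classifier + conv block 4
#   3: classifier + conv blocks 4-3
#   4: classifier + conv blocks 4-2 (default)
#   5: fine-tune the whole network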

args = parser.parse_args()
datapath = args.datapath
outpath = args.outpath
modelpath = args.modelpath
scheme = args.scheme
rate = args.trfrate
lr = args.lr
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200

# The 54 subject IDs in a fixed, pre-shuffled order; the fold index in the
# training loop below follows this order.
subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]


# Get the data of a single subject from the HDF5 file.
def get_data(subj):
    dpath = '/s' + str(subj)
    X = dfile[pjoin(dpath, 'X')]
    Y = dfile[pjoin(dpath, 'Y')]
    return X[:], Y[:]
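
# HDF5 layout assumed by get_data (shapes are inferred from how X and Y are
# used below; the dataset names X_s1/Y_s1 are hypothetical):
#   /s{subj}/X : float array of shape (trials, channels, time samples)
#   /s{subj}/Y : integer class labels of shape (trials,)
# A compatible file could be written along these lines:
#   with h5py.File('data.h5', 'w') as f:
#       f.create_dataset('s1/X', data=X_s1)
#       f.create_dataset('s1/Y', data=Y_s1)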


X, Y = get_data(subjs[0])
n_classes = 2
in_chans = X.shape[1]
# final_conv_length='auto' ensures we only get a single output in the time dimension
model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                 input_time_length=X.shape[2],
                 final_conv_length='auto').cuda()
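
# Deep4Net exposes the underlying torch module as model.network; the block
# attributes referenced below (conv_classifier, conv_2..conv_4 and
# bnorm_2..bnorm_4) belong to that module in this braindecode version.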

# Deprecated: reset_conv_pool_block below is kept for reference but is no
# longer called (its call sites in reset_model are commented out).


def reset_conv_pool_block(network, block_nr):
    suffix = "_{:d}".format(block_nr)
    conv = getattr(network, 'conv' + suffix)
    kernel_size = conv.kernel_size
    n_filters_before = conv.in_channels
    n_filters = conv.out_channels
    # Replace the conv and batch-norm layers of the given block with freshly
    # constructed ones of the same shape.
    setattr(network, 'conv' + suffix,
            nn.Conv2d(n_filters_before, n_filters, kernel_size,
                      stride=(1, 1), bias=False))
    setattr(network, 'bnorm' + suffix,
            nn.BatchNorm2d(n_filters, momentum=0.1, affine=True, eps=1e-5))
    # Re-initialize the new layers.
    conv = getattr(network, 'conv' + suffix)
    bnorm = getattr(network, 'bnorm' + suffix)
    nn.init.xavier_uniform_(conv.weight, gain=1)
    nn.init.constant_(bnorm.weight, 1)
    nn.init.constant_(bnorm.bias, 0)


def reset_model(checkpoint):
    # Load the state dict of the base model.
    model.network.load_state_dict(checkpoint['model_state_dict'])

    # # Resets the last conv block
    # reset_conv_pool_block(model.network, block_nr=4)
    # reset_conv_pool_block(model.network, block_nr=3)
    # reset_conv_pool_block(model.network, block_nr=2)
    # # Resets the fully-connected layer.
    # # Parameters of newly constructed modules have requires_grad=True by default.
    # n_final_conv_length = model.network.conv_classifier.kernel_size[0]
    # n_prev_filter = model.network.conv_classifier.in_channels
    # n_classes = model.network.conv_classifier.out_channels
    # model.network.conv_classifier = nn.Conv2d(
    #     n_prev_filter, n_classes, (n_final_conv_length, 1), bias=True)
    # nn.init.xavier_uniform_(model.network.conv_classifier.weight, gain=1)
    # nn.init.constant_(model.network.conv_classifier.bias, 0)

    if scheme != 5:
        # Freeze all layers.
        for param in model.network.parameters():
            param.requires_grad = False

        if scheme in {1, 2, 3, 4}:
            # Unfreeze the FC (classifier) layer.
            for param in model.network.conv_classifier.parameters():
                param.requires_grad = True

        if scheme in {2, 3, 4}:
            # Unfreeze the conv4 layer.
            for param in model.network.conv_4.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_4.parameters():
                param.requires_grad = True

        if scheme in {3, 4}:
            # Unfreeze the conv3 layer.
            for param in model.network.conv_3.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_3.parameters():
                param.requires_grad = True

        if scheme == 4:
            # Unfreeze the conv2 layer.
            for param in model.network.conv_2.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_2.parameters():
                param.requires_grad = True

    # Only optimize parameters that require gradients.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.network.parameters()),
                      lr=lr, weight_decay=0.5 * 0.001)  # weight decay of 5e-4
    model.compile(loss=F.nll_loss, optimizer=optimizer,
                  iterator_seed=20200205)
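
# Note: in this braindecode version, model.compile attaches the loss and
# optimizer to the Keras-style training API used below (model.fit /
# model.evaluate); F.nll_loss pairs with Deep4Net's log-softmax output.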

# Use only session 1 data (the first 200 trials) for adaptation.
cutoff = int(rate * 200 / 100)
assert cutoff <= 200
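# For example, -trfrate 50 gives cutoff = int(50 * 200 / 100) = 100, i.e. the
# first 100 session-1 trials are used for adaptation.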

for fold, subj in enumerate(subjs):
    suffix = '_s' + str(subj) + '_f' + str(fold)
    checkpoint = torch.load(pjoin(modelpath, 'model_f' + str(fold) + '.pt'),
                            map_location='cuda:' + str(args.gpu))
    reset_model(checkpoint)

    # The first `cutoff` trials (session 1) adapt the model; trials 200-299
    # are used for validation and trials from 300 on for the final test.
    X, Y = get_data(subj)
    X_train, Y_train = X[:cutoff], Y[:cutoff]
    X_val, Y_val = X[200:300], Y[200:300]
    X_test, Y_test = X[300:], Y[300:]
    model.fit(X_train, Y_train, epochs=TRAIN_EPOCH,
              batch_size=BATCH_SIZE, scheduler='cosine',
              validation_data=(X_val, Y_val), remember_best_column='valid_loss')
    model.epochs_df.to_csv(pjoin(outpath, 'epochs' + suffix + '.csv'))
    test_loss = model.evaluate(X_test, Y_test)
    with open(pjoin(outpath, 'test' + suffix + '.json'), 'w') as f:
        json.dump(test_loss, f)

dfile.close()
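
# The run leaves one epochs_*.csv (per-epoch metrics) and one test_*.json
# (final test metrics) per subject in outpath; they could be inspected, e.g.,
# with (illustrative, assuming pandas is available):
#   import pandas as pd
#   df = pd.read_csv(pjoin(outpath, 'epochs_s35_f0.csv'))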