Diff of /train_base.py [000000] .. [5d1c0a]

#!/usr/bin/env python
# coding: utf-8
'''Subject-independent classification with KU Data,
using the Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''

import argparse
import json
import logging
import sys
from os import makedirs
from os.path import join as pjoin
from shutil import copy2, move

import h5py
import numpy as np
import torch
import torch.nn.functional as F
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from sklearn.model_selection import KFold

logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

parser = argparse.ArgumentParser(
    description='Subject-independent classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('outpath', type=str, help='Path to the result folder')
parser.add_argument('-fold', type=int,
                    help='k-fold index, starts with 0', required=True)
parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)

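# Example invocation (paths are purely illustrative, not part of this repository):
#   python train_base.py ./data/ku_eeg.h5 ./results/base -fold 0 -gpu 0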
args = parser.parse_args()
datapath = args.datapath
outpath = args.outpath
fold = args.fold
assert 0 <= fold < 54
# Randomly shuffled list of the 54 subjects.
subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]
test_subj = subjs[fold]
cv_set = np.array(subjs[fold+1:] + subjs[:fold])
kf = KFold(n_splits=6)

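# Evaluation scheme implied by the code above: subjs[fold] is held out as the test
# subject, and the remaining 53 subjects are split into 6 folds, each fold providing
# a train/validation partition for model selection.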
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200  # consider 200 for early stopping

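# Assumed HDF5 layout (inferred from the helper functions below, not verified against
# the original data file): one group per subject, with '/s<subject>/X' holding a
# (trials, channels, samples) array and '/s<subject>/Y' the corresponding class labels.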
# Data loading helpers.


def get_data(subj):
    '''Return the X and Y datasets of a single subject.'''
    dpath = '/s' + str(subj)
    X = dfile[pjoin(dpath, 'X')]
    Y = dfile[pjoin(dpath, 'Y')]
    return X, Y


def get_multi_data(subjs):
    '''Load and concatenate the trials of several subjects.'''
    Xs = []
    Ys = []
    for s in subjs:
        x, y = get_data(s)
        Xs.append(x[:])
        Ys.append(y[:])
    X = np.concatenate(Xs, axis=0)
    Y = np.concatenate(Ys, axis=0)
    return X, Y


cv_loss = []
for cv_index, (train_index, test_index) in enumerate(kf.split(cv_set)):

    train_subjs = cv_set[train_index]
    valid_subjs = cv_set[test_index]
    X_train, Y_train = get_multi_data(train_subjs)
    X_val, Y_val = get_multi_data(valid_subjs)
    X_test, Y_test = get_data(test_subj)
    train_set = SignalAndTarget(X_train, y=Y_train)
    valid_set = SignalAndTarget(X_val, y=Y_val)
    # Only trials from index 200 onward of the held-out subject are used for testing.
    test_set = SignalAndTarget(X_test[200:], y=Y_test[200:])
    n_classes = 2
    in_chans = train_set.X.shape[1]

    # final_conv_length='auto' ensures we only get a single output in the time dimension.
    model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                     input_time_length=train_set.X.shape[2],
                     final_conv_length='auto').cuda()
    # These are good values for the deep model.
    optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001)
    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)

    # Fit the base model for transfer learning; early stopping is used as a hack
    # to remember the best model seen during training.
    exp = model.fit(train_set.X, train_set.y, epochs=TRAIN_EPOCH, batch_size=BATCH_SIZE, scheduler='cosine',
                    validation_data=(valid_set.X, valid_set.y), remember_best_column='valid_loss')
    rememberer = exp.rememberer
    base_model_param = {
        'epoch': rememberer.best_epoch,
        'model_state_dict': rememberer.model_state_dict,
        'optimizer_state_dict': rememberer.optimizer_state_dict,
        'loss': rememberer.lowest_val
    }
    torch.save(base_model_param, pjoin(
        outpath, 'model_f{}_cv{}.pt'.format(fold, cv_index)))
    model.epochs_df.to_csv(
        pjoin(outpath, 'epochs_f{}_cv{}.csv'.format(fold, cv_index)))
    cv_loss.append(rememberer.lowest_val)

    test_loss = model.evaluate(test_set.X, test_set.y)
    with open(pjoin(outpath, 'test_base_s{}_f{}_cv{}.json'.format(test_subj, fold, cv_index)), 'w') as f:
        json.dump(test_loss, f)

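# The saved checkpoints can later be restored with torch.load; the dict keys
# ('epoch', 'model_state_dict', 'optimizer_state_dict', 'loss') match what is saved
# above. A minimal sketch (the attribute that holds the underlying torch module may
# differ between braindecode versions):
#   ckpt = torch.load(pjoin(outpath, 'model_f{}_cv{}.pt'.format(fold, 0)))
#   model.network.load_state_dict(ckpt['model_state_dict'])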
# Keep the artifacts of the CV split with the lowest validation loss in a "best" folder.
best_cv = np.argmin(cv_loss)
best_dir = pjoin(outpath, "best")
makedirs(best_dir, exist_ok=True)
with open(pjoin(best_dir, "fold_bestcv.txt"), 'a') as f:
    f.write("{}, {}\n".format(fold, best_cv))
copy2(pjoin(outpath, 'model_f{}_cv{}.pt'.format(fold, best_cv)),
      pjoin(best_dir, 'model_f{}.pt'.format(fold)))
copy2(pjoin(outpath, 'epochs_f{}_cv{}.csv'.format(fold, best_cv)),
      pjoin(best_dir, 'epochs_f{}.csv'.format(fold)))
copy2(pjoin(outpath, 'test_base_s{}_f{}_cv{}.json'.format(test_subj, fold, best_cv)),
      pjoin(best_dir, 'test_base_s{}_f{}.json'.format(test_subj, fold)))
dfile.close()
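# The checkpoints collected in <outpath>/best presumably serve as the pretrained base
# models for a later subject-specific fine-tuning step (the transfer-learning stage
# the comments above refer to).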