# main.py
1
""" Code for the Main function of MetaPred. """
2
import os, csv
3
import numpy as np
4
import random
5
import pickle as pkl
6
import tensorflow as tf
7
import copy
8
from tensorflow.python.platform import flags
9
10
from data_loader import DataLoader
11
import model, finetune
12
13
14
FLAGS = flags.FLAGS
15
flags.DEFINE_string('source', 'AD', 'source task')
16
flags.DEFINE_string('target', 'MCI', 'simulated task')
17
flags.DEFINE_string('true_target', 'PD', 'true task')
18
19
## Dataset/method options
20
flags.DEFINE_integer('n_classes', 2, 'number of classes used in classification (e.g. binary classification)')
21
22
## Training options
23
flags.DEFINE_string('method', 'rnn', 'deep learning methods for modeling')
24
flags.DEFINE_integer('pretrain_iterations', 20000, 'number of pre-training iterations')
25
flags.DEFINE_integer('metatrain_iterations', 10000, 'number of metatraining iterations') # 15k for omniglot, 50k for sinusoid
26
flags.DEFINE_integer('meta_batch_size', 8, 'number of tasks sampled per meta-update')
27
flags.DEFINE_integer('update_batch_size', 16, 'number of samples used for inner gradient update (K for K-shot learning)')
28
flags.DEFINE_float('meta_lr', 0.0001, 'the base learning rate of the generator')
29
flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update')
30
flags.DEFINE_integer('num_updates', 4, 'number of inner gradient updates during training')
31
flags.DEFINE_integer('n_total_batches', 100000, 'total batches generated by random sampling')
32
33
34
## Model options
35
flags.DEFINE_string('norm', 'None', 'batch_norm, layer_norm, or None')
36
flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
37
flags.DEFINE_bool('isReg', True, 'if True, compute regularization of weights and bias')
38
flags.DEFINE_float('dropout', 0.5, 'drop out when modeling, with probability keep_prob')
39
40
## Logging, saving, and testing options
41
flags.DEFINE_integer('run_time', 1, 're-run for stable analysis')
42
flags.DEFINE_bool('train', True, 'True to train, False to test directly')
43
flags.DEFINE_bool('test', True, 'True to test, no matter the model is trained')
44
flags.DEFINE_bool('finetune', False, 'True to finetunning furthermore, after meta-learning')
45
flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code')
46
flags.DEFINE_string('logdir', 'model/', 'directory for summaries and checkpoints')
47
flags.DEFINE_bool('resume', False, 'resume training if there is a model available')
48
flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
49
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number)')
50
flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step step during training. (use if you want to test with a different value)') # 0.1 for omniglot
51
52
53
def train(data_loader, ifold, exp_string):
    """Build a MetaPred model and meta-train it on fold `ifold`.

    Args:
        data_loader: DataLoader providing training episodes and per-fold
            validation episodes.
        ifold: index of the cross-validation fold.
        exp_string: experiment identifier used for logging/checkpoints.

    Returns:
        (meta_model, session): the fitted MetaPred model and the tf session
        returned by its fit() call.
    """
    print("constructing MetaPred model ...")
    meta_model = model.MetaPred(data_loader, FLAGS.meta_lr, FLAGS.update_lr)

    print("model training...")
    session = meta_model.fit(data_loader.episode,
                             data_loader.episode_val[ifold],
                             ifold, exp_string)
    return meta_model, session
def test(data_loader, ifold, m, sess, exp_string):
    """Meta-test a trained model on fold `ifold`'s true-target validation data.

    Args:
        data_loader: DataLoader holding source data and per-fold true-target
            validation splits.
        ifold: index of the cross-validation fold.
        m: trained MetaPred model exposing evaluate().
        sess: tf session the model was trained in.
        exp_string: experiment identifier (unused here; kept for symmetry
            with train()).

    Returns:
        (accs, aucs, ap, f1s) as produced by m.evaluate().
    """
    print("model test...")
    eval_tuple = (data_loader.data_s,
                  data_loader.data_tt_val[ifold],
                  data_loader.label_s,
                  data_loader.label_tt_val[ifold])
    accs, aucs, ap, f1s = m.evaluate(data_loader.episode_val[ifold],
                                     eval_tuple, sess=sess, prefix="metatest_")
    print('Test results: ' + "ifold: " + str(ifold) + ": tAcc: " + str(accs) +
          ", tAuc: " + str(aucs) + ", tAP: " + str(ap) + ", tF1: " + str(f1s))
    return accs, aucs, ap, f1s
def fine_tune(data_loader, ifold, meta_m, weights_for_finetune, exp_string, freeze_opt=None):
    """Fine-tune a network on the true target task from meta-learned weights.

    Args:
        data_loader: DataLoader holding per-fold true-target train/val splits.
        ifold: index of the cross-validation fold to fine-tune on.
        meta_m: the trained meta-model (currently unused; kept for call-site
            compatibility).
        weights_for_finetune: meta-learned weights used to initialize the net.
        exp_string: experiment identifier (currently unused; kept for call-site
            compatibility).
        freeze_opt: which layers to freeze while fine-tuning (None trains all).
            NOTE(review): the original referenced an undefined global
            `freeze_opt`, raising NameError; it is now an explicit parameter.

    Returns:
        (m2, sess): the fine-tuned model and the tf session it trained in.

    Raises:
        ValueError: if FLAGS.method is neither "cnn" nor "rnn" (the original
            left m2 unbound and crashed with UnboundLocalError instead).
    """
    is_finetune = True
    print("finetunning MetaPred model ...")
    if FLAGS.method == "cnn":
        m2 = finetune.CNN(data_loader, weights_for_finetune, freeze_opt=freeze_opt, is_finetune=is_finetune)
    elif FLAGS.method == "rnn":
        m2 = finetune.RNN(data_loader, weights_for_finetune, freeze_opt=freeze_opt, is_finetune=is_finetune)
    else:
        raise ValueError("unsupported method for finetuning: %s" % FLAGS.method)
    print("model finetunning...")

    # model finetunning: fit on the fold's true-target training split,
    # validating on the fold's held-out split
    sess, _, _ = m2.fit(data_loader.tt_sample[ifold], data_loader.tt_label[ifold],
                  data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold])
    return m2, sess
def save_results(metatest, exp_string):
    """Write mean and std of each meta-test metric to results/res_<exp_string>.

    Each metric produces two consecutive single-column CSV rows: the mean of
    the per-fold scores, then their standard deviation.

    Args:
        metatest: dict mapping metric name -> list of per-fold scores.
        exp_string: experiment identifier used to build the output filename.
    """
    os.makedirs("results", exist_ok=True)  # the original crashed if results/ was missing
    out_filename = "results/res_" + exp_string
    # newline='' stops the csv module from emitting extra blank rows on Windows
    with open(out_filename, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        for key in metatest:
            scores = np.array(metatest[key])
            writer.writerow([np.mean(scores)])
            writer.writerow([np.std(scores)])
    print("results saved")
def save_weights(meta_m, source, target, true_target, data_loader, ifold):
    """Pickle the meta-learned weights and fold `ifold`'s true-target splits.

    Three files are written under weights/: the meta-learned weights, the
    true-target training split, and the true-target validation split, all
    sharing one filename scheme built from the task names.

    Args:
        meta_m: trained meta-model exposing `weights_for_finetune`.
        source: list of source task names.
        target: list of simulated target task names.
        true_target: list of true target task names.
        data_loader: DataLoader holding the true-target splits.
        ifold: cross-validation fold index.
    """
    def _path(tag):
        # Single place for the shared filename scheme (the original repeated it 3x).
        return ("weights/meta-" + FLAGS.method + tag
                + ".source_" + "-".join(source)
                + ".starget_" + "".join(target)
                + ".ttarget_" + "".join(true_target) + ".pkl")

    # protocol=2 keeps the pickles loadable from Python 2 as well.
    # NOTE: the original also called f.close() inside each `with`, which is
    # redundant -- the context manager already closes the file.
    with open(_path(".weights"), 'wb') as f:
        pkl.dump(meta_m.weights_for_finetune, f, protocol=2)
    with open(_path(".tt_train"), 'wb') as f:
        pkl.dump((data_loader.tt_sample[ifold], data_loader.tt_label[ifold]), f, protocol=2)
    with open(_path(".tt_val"), 'wb') as f:
        pkl.dump((data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold]), f, protocol=2)
    print("model weights saved")
def main():
    """Entry point: run n-fold train / optional finetune / test for MetaPred.

    Reads all configuration from FLAGS, builds the DataLoader over the
    source / simulated-target / true-target tasks, then per fold optionally
    trains, fine-tunes, and meta-tests, finally printing and saving
    mean/std metrics across folds.
    """
    print(FLAGS.method)
    # set source and simulated target for training
    print('task setting: ')
    source = [FLAGS.source]
    target = [FLAGS.target]
    true_target = [FLAGS.true_target]

    print("The applied source tasks are: ", " ".join(source))
    print("The simulated target task is: ", " ".join(target))
    print("The true target task is: ", " ".join(true_target))
    n_tasks = len(source) + len(target)

    # load ehrs data
    data_loader = DataLoader(source, target, true_target, n_tasks,
                             FLAGS.update_batch_size, FLAGS.meta_batch_size)

    exp_string = 'stsk_'+str('&'.join(source))+'ttsk_'+str('&'.join(target))+'.mbs_'+str(FLAGS.meta_batch_size) + \
                       '.ubs_' + str(FLAGS.update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.update_lr)

    metatest = {'aucroc': [], 'avepre': [], 'f1score': []} # n_fold result
    n_fold = data_loader.n_fold
    for ifold in range(n_fold):
        print("----------The %d-th fold-----------" %(ifold+1))
        meta_model = None
        # Initialize sess so --train=False --test=True fails with a clear error
        # instead of the NameError the original raised.
        sess = None

        if FLAGS.train:
            meta_model, sess = train(data_loader, ifold, exp_string)
            save_weights(meta_model, source, target, true_target, data_loader, ifold)

        if FLAGS.finetune:
            with open("weights/meta-" + FLAGS.method + ".weights" + ".source_" + "-".join(source) + ".starget_" + "".join(target) + ".ttarget_" + "".join(true_target) + ".pkl", 'rb') as f:
                weights_for_finetune = pkl.load(f)
            # Renamed from `model`: the original shadowed the imported `model` module.
            ft_model, sess = fine_tune(data_loader, ifold, meta_model, weights_for_finetune, exp_string)

        if FLAGS.test:
            if meta_model is None or sess is None:
                raise RuntimeError("cannot test: no trained model in this run (enable --train)")
            _, test_aucs, test_ap, test_f1s = test(data_loader, ifold, meta_model, sess, exp_string)
            metatest['aucroc'].append(test_aucs)
            metatest['avepre'].append(test_ap)
            metatest['f1score'].append(test_f1s)

    # show results
    print ('--------------- model setting ---------------')
    print('source: ', " ".join(source), 'simulated target: ', " ".join(target), 'true target: ', " ".join(true_target))
    print('method:', 'meta-' + FLAGS.method, 'meta-bz:', FLAGS.meta_batch_size, 'update-bz:', FLAGS.update_batch_size, \
          'num update:', FLAGS.num_updates, 'meta-lr:', FLAGS.meta_lr, 'update-lr:', FLAGS.update_lr)

    print ('--------------- 5fold results ---------------')
    print ('aucroc mean: ', np.mean(np.array(metatest['aucroc'])))
    print ('aucroc std: ', np.std(np.array(metatest['aucroc'])))
    print ('f1score mean: ', np.mean(np.array(metatest['f1score'])))
    print ('f1score std: ', np.std(np.array(metatest['f1score'])))
    save_results(metatest, exp_string)
# Script entry point: run the full train/finetune/test pipeline.
if __name__ == "__main__":
    main()