""" Code for the Main function of MetaPred. """
import os
import csv
import copy
import random
import pickle as pkl

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags

from data_loader import DataLoader
import model, finetune


FLAGS = flags.FLAGS
flags.DEFINE_string('source', 'AD', 'source task')
flags.DEFINE_string('target', 'MCI', 'simulated target task')
flags.DEFINE_string('true_target', 'PD', 'true target task')

## Dataset/method options
flags.DEFINE_integer('n_classes', 2, 'number of classes used in classification (e.g. binary classification)')

## Training options
flags.DEFINE_string('method', 'rnn', 'deep learning method used for modeling (cnn or rnn)')
flags.DEFINE_integer('pretrain_iterations', 20000, 'number of pre-training iterations')
flags.DEFINE_integer('metatrain_iterations', 10000, 'number of meta-training iterations') # 15k for omniglot, 50k for sinusoid
flags.DEFINE_integer('meta_batch_size', 8, 'number of tasks sampled per meta-update')
flags.DEFINE_integer('update_batch_size', 16, 'number of samples used for the inner gradient update (K for K-shot learning)')
flags.DEFINE_float('meta_lr', 0.0001, 'base learning rate of the meta-optimizer')
flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for the inner gradient update')
flags.DEFINE_integer('num_updates', 4, 'number of inner gradient updates during training')
flags.DEFINE_integer('n_total_batches', 100000, 'total number of batches generated by random sampling')

## Model options
flags.DEFINE_string('norm', 'None', 'batch_norm, layer_norm, or None')
flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
flags.DEFINE_bool('isReg', True, 'if True, apply regularization to the weights and biases')
flags.DEFINE_float('dropout', 0.5, 'keep probability used for dropout during modeling')

## Logging, saving, and testing options
flags.DEFINE_integer('run_time', 1, 'number of re-runs for stability analysis')
flags.DEFINE_bool('train', True, 'True to train, False to test directly')
flags.DEFINE_bool('test', True, 'True to test, whether or not the model was just trained')
flags.DEFINE_bool('finetune', False, 'True to fine-tune further after meta-learning')
flags.DEFINE_bool('log', True, 'if False, do not log summaries (for debugging)')
flags.DEFINE_string('logdir', 'model/', 'directory for summaries and checkpoints')
flags.DEFINE_bool('resume', False, 'resume training if there is a model available')
flags.DEFINE_integer('test_iter', -1, 'iteration to load the model from (-1 for the latest model)')
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for the gradient update during training (set this to test with a different number)')
flags.DEFINE_float('train_update_lr', -1, 'inner gradient step size used during training (set this to test with a different value)') # 0.1 for omniglot
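
# Example invocation using the flags defined above (defaults shown explicitly;
# any flag can be overridden on the command line):
#   python main.py --method rnn --source AD --target MCI --true_target PD \
#       --meta_batch_size 8 --update_batch_size 16 --num_updates 4 --update_lr 1e-3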
|
|
|
def train(data_loader, ifold, exp_string):
    # construct MetaPred model
    print("constructing MetaPred model ...")
    m1 = model.MetaPred(data_loader, FLAGS.meta_lr, FLAGS.update_lr)
    # fit the meta-learning model
    print("model training...")
    sess = m1.fit(data_loader.episode, data_loader.episode_val[ifold], ifold, exp_string)
    return m1, sess

def test(data_loader, ifold, m, sess, exp_string):
    # meta-test the model
    print("model test...")
    data_tuple_val = (data_loader.data_s, data_loader.data_tt_val[ifold],
                      data_loader.label_s, data_loader.label_tt_val[ifold])
    test_accs, test_aucs, test_ap, test_f1s = m.evaluate(data_loader.episode_val[ifold], data_tuple_val,
                                                         sess=sess, prefix="metatest_")
    print('Test results: ' + "ifold: " + str(ifold) + ": tAcc: " + str(test_accs) +
          ", tAuc: " + str(test_aucs) + ", tAP: " + str(test_ap) + ", tF1: " + str(test_f1s))
    return test_accs, test_aucs, test_ap, test_f1s


def fine_tune(data_loader, ifold, meta_m, weights_for_finetune, exp_string):
    # construct the fine-tuning model on top of the meta-learned weights
    is_finetune = True
    freeze_opt = None  # assumed: no layers are frozen during fine-tuning
    print("fine-tuning MetaPred model ...")
    if FLAGS.method == "cnn":
        m2 = finetune.CNN(data_loader, weights_for_finetune, freeze_opt=freeze_opt, is_finetune=is_finetune)
    elif FLAGS.method == "rnn":
        m2 = finetune.RNN(data_loader, weights_for_finetune, freeze_opt=freeze_opt, is_finetune=is_finetune)
    else:
        raise ValueError("unsupported method: %s" % FLAGS.method)
    print("model fine-tuning...")

    # fit on the true-target training split, validating on the held-out fold
    sess, _, _ = m2.fit(data_loader.tt_sample[ifold], data_loader.tt_label[ifold],
                        data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold])
    return m2, sess


def save_results(metatest, exp_string):
    # make sure the output directory exists before writing
    if not os.path.exists("results"):
        os.makedirs("results")
    out_filename = "results/res_" + exp_string
    with open(out_filename, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        # write one row for the mean and one for the std of each metric
        for key in metatest:
            writer.writerow([np.mean(np.array(metatest[key]))])
            writer.writerow([np.std(np.array(metatest[key]))])
    print("results saved")
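
# With the default flags, exp_string (built in main below) resolves to
# "stsk_ADttsk_MCI.mbs_8.ubs_16.numstep4.updatelr0.001", so the per-metric means
# and stds land in results/res_stsk_ADttsk_MCI.mbs_8.ubs_16.numstep4.updatelr0.001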
|
|
|
def save_weights(meta_m, source, target, true_target, data_loader, ifold):
    # shared filename suffix encoding the source / simulated-target / true-target setting
    suffix = ".source_" + "-".join(source) + ".starget_" + "".join(target) + ".ttarget_" + "".join(true_target) + ".pkl"
    # meta-learned weights, reused later for fine-tuning
    with open("weights/meta-" + FLAGS.method + ".weights" + suffix, 'wb') as f:
        pkl.dump(meta_m.weights_for_finetune, f, protocol=2)
    # true-target training and validation splits for this fold
    with open("weights/meta-" + FLAGS.method + ".tt_train" + suffix, 'wb') as f:
        pkl.dump((data_loader.tt_sample[ifold], data_loader.tt_label[ifold]), f, protocol=2)
    with open("weights/meta-" + FLAGS.method + ".tt_val" + suffix, 'wb') as f:
        pkl.dump((data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold]), f, protocol=2)
    print("model weights saved")
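
# With the default flags ('rnn', 'AD', 'MCI', 'PD'), save_weights writes:
#   weights/meta-rnn.weights.source_AD.starget_MCI.ttarget_PD.pkl
#   weights/meta-rnn.tt_train.source_AD.starget_MCI.ttarget_PD.pkl
#   weights/meta-rnn.tt_val.source_AD.starget_MCI.ttarget_PD.pkl
# The first file is the one the --finetune branch in main() reloads.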
|
|
|
def main():
    print(FLAGS.method)
    # set the source and simulated target tasks for training
    print('task setting: ')
    source = [FLAGS.source]
    target = [FLAGS.target]
    true_target = [FLAGS.true_target]

    print("The applied source tasks are: ", " ".join(source))
    print("The simulated target task is: ", " ".join(target))
    print("The true target task is: ", " ".join(true_target))
    n_tasks = len(source) + len(target)

    # load the EHR data
    data_loader = DataLoader(source, target, true_target, n_tasks,
                             FLAGS.update_batch_size, FLAGS.meta_batch_size)

    exp_string = 'stsk_' + str('&'.join(source)) + 'ttsk_' + str('&'.join(target)) + '.mbs_' + str(FLAGS.meta_batch_size) + \
                 '.ubs_' + str(FLAGS.update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.update_lr)

    metatest = {'aucroc': [], 'avepre': [], 'f1score': []}  # n_fold results
    n_fold = data_loader.n_fold
    for ifold in range(n_fold):
        print("----------The %d-th fold-----------" % (ifold + 1))
        meta_model = None

        if FLAGS.train:
            meta_model, sess = train(data_loader, ifold, exp_string)
            save_weights(meta_model, source, target, true_target, data_loader, ifold)

        if FLAGS.finetune:
            with open("weights/meta-" + FLAGS.method + ".weights" + ".source_" + "-".join(source) + ".starget_" + "".join(target) + ".ttarget_" + "".join(true_target) + ".pkl", 'rb') as f:
                weights_for_finetune = pkl.load(f)
            # keep the result under a distinct name so the imported `model` module is not shadowed
            finetuned_model, sess = fine_tune(data_loader, ifold, meta_model, weights_for_finetune, exp_string)

        if FLAGS.test:
            # testing reuses the session produced by training (or fine-tuning) above
            _, test_aucs, test_ap, test_f1s = test(data_loader, ifold, meta_model, sess, exp_string)
            metatest['aucroc'].append(test_aucs)
            metatest['avepre'].append(test_ap)
            metatest['f1score'].append(test_f1s)

    # show results
    print('--------------- model setting ---------------')
    print('source: ', " ".join(source), 'simulated target: ', " ".join(target), 'true target: ', " ".join(true_target))
    print('method:', 'meta-' + FLAGS.method, 'meta-bz:', FLAGS.meta_batch_size, 'update-bz:', FLAGS.update_batch_size,
          'num update:', FLAGS.num_updates, 'meta-lr:', FLAGS.meta_lr, 'update-lr:', FLAGS.update_lr)

    print('--------------- 5fold results ---------------')
    print('aucroc mean: ', np.mean(np.array(metatest['aucroc'])))
    print('aucroc std: ', np.std(np.array(metatest['aucroc'])))
    print('avepre mean: ', np.mean(np.array(metatest['avepre'])))
    print('avepre std: ', np.std(np.array(metatest['avepre'])))
    print('f1score mean: ', np.mean(np.array(metatest['f1score'])))
    print('f1score std: ', np.std(np.array(metatest['f1score'])))
    save_results(metatest, exp_string)

if __name__ == "__main__":
    main()