""" Code for the MetaPred algorithm and network architecture. """
import numpy as np
import sklearn
import tensorflow as tf
import os, time, shutil, collections

import tensorflow.contrib.layers as layers
from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS

PADDING_ID = 1016
WORDS_NUM = 1017
MASK_ARRAY = [[1.]] * PADDING_ID + [[0.]] + [[1.]] * (WORDS_NUM - PADDING_ID - 1)
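# MASK_ARRAY is a (WORDS_NUM, 1) lookup table: 1. for every vocabulary index except
# PADDING_ID, which maps to 0. Looked up in parallel with the embedding table, it
# zeroes out the embeddings of padding codes before they are summed (see embedding()).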


SUMMARY_INTERVAL = 100
SAVE_INTERVAL = 1000
PRINT_INTERVAL = 100
TEST_PRINT_INTERVAL = PRINT_INTERVAL*5

class BaseModel(object):
    """
    Base Model for basic networks with sequential data, i.e., RNN, CNN.
    """
    def __init__(self):
        self.regularizers = []

    def convert_to_array(self, data):
        '''Convert the input to a numpy array (densify sparse matrices).'''
        if type(data) is not np.ndarray:
            # data = np.array(data)
            data = data.toarray()  # convert sparse matrices
        return data

    # Helper methods.
    def _get_path(self, folder):
        path = '../../models/'
        return os.path.join(path, folder, self.dir_name)

    def _get_session(self, sess=None):
        '''Restore parameters if no session given.'''
        if sess is None:
            sess = tf.Session(graph=self.graph)
            filename = tf.train.latest_checkpoint(self._get_path('checkpoints'))
            self.op_saver.restore(sess, filename)
        return sess

    def _get_prediction(self, logits):
        '''Return the predicted classes.'''
        with tf.name_scope('prediction'):
            prediction = tf.argmax(logits, axis=1)
        return prediction

    def loss_func(self, pred, label):
        '''Softmax cross-entropy, scaled by 1/update_batch_size.'''
        # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives
        label = tf.one_hot(label, FLAGS.n_classes)
        return tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=label) / FLAGS.update_batch_size
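    # Note: loss_func returns per-example cross-entropy divided by update_batch_size,
    # so the tf.reduce_sum applied to these losses in build_model yields (roughly) a
    # mean loss over the K examples sampled per task.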


class MetaPred(BaseModel):
    def __init__(self, data_loader, meta_lr=1e-3, update_lr=1e-2, test_num_updates=-1):
        """
        Args:
            data_loader: provides dim_input (input data dimension) and n_tasks
                         (number of tasks, including both source and target)
            meta_lr: the base learning rate of the meta (outer) optimizer
            update_lr: step size alpha for the inner gradient update
            test_num_updates: number of inner updates at evaluation; the graph uses
                              max(test_num_updates, FLAGS.num_updates)
        """
        super().__init__()

        self.data_loader = data_loader
        self.dim_input = data_loader.dim_input
        self.n_tasks = data_loader.n_tasks
        self.meta_lr = meta_lr
        self.update_lr = update_lr
        self.test_num_updates = test_num_updates
        self.auc_stable = []
        self.f1s_stable = []

        self.weights_for_finetune = dict() # to store the value of learned params

        print('method:', "meta-"+FLAGS.method, 'data shape:', self.dim_input, 'meta-bz:', FLAGS.meta_batch_size, 'update-bz:', FLAGS.update_batch_size,
              'num update:', FLAGS.num_updates, 'meta-lr:', meta_lr, 'update-lr:', update_lr)

        if FLAGS.method == "cnn":
            # sequential network (cnn) configuration
            self.cnn_config(data_loader)
        elif FLAGS.method == "rnn":
            # sequential network (rnn) configuration
            self.rnn_config(data_loader)

        # Build the computational graph.
        self.build_graph()

    ####################################### Networks #######################################
    def weight_variable(self, shape, name='weights'):
        if FLAGS.pretrain:
            initial = self.pretrain_weights[name]
            var = tf.Variable(initial_value=initial, name=name)
        else:
            initial = tf.truncated_normal_initializer(0, 0.1)
            var = tf.get_variable(name, shape, tf.float32, initializer=initial)

        if FLAGS.isReg:
            self.regularizers.append(tf.nn.l2_loss(var))
        tf.summary.histogram(var.op.name, var)
        return var

    def bias_variable(self, shape, initial=None, name='bias'):
        if FLAGS.pretrain:
            initial = self.pretrain_weights[name]
            var = tf.Variable(initial_value=initial, name=name)
        else:
            initial = tf.constant_initializer(0.1)
            var = tf.get_variable(name, shape, tf.float32, initializer=initial)

        if FLAGS.isReg:
            self.regularizers.append(tf.nn.l2_loss(var))
        tf.summary.histogram(var.op.name, var)
        return var

    ############################### Fully Connected Network #################################
    # construct weights
    def build_fc_weights(self, dim_in, weights):
        for i, dim in enumerate(self.dim_hidden):
            dim_out = dim
            weights["fc_W"+str(i)] = self.weight_variable([int(dim_in), dim_out], name="fc_W"+str(i))
            weights["fc_b"+str(i)] = self.bias_variable([dim_out], name="fc_b"+str(i))
            dim_in = dim_out
        return weights

    def fc(self, x, W, b, relu=True):
        """Fully connected layer, y = xW + b, with optional ReLU."""
        x = tf.matmul(x, W) + b
        return tf.nn.relu(x) if relu else x

    ############################ Embedding Layer for SeqNet ################################
    def build_emb_weights(self, weights):
        weights["emb_W"] = tf.Variable(tf.random_normal([self.n_words, self.n_hidden], stddev=self.init_std), name="emb_W")
        with tf.variable_scope("emb", reuse=tf.AUTO_REUSE) as scope:
            weights["emb_mask_W"] = tf.get_variable("mask_padding", initializer=MASK_ARRAY, dtype="float32", trainable=False)
        return weights

    def embedding(self, x, Wemb, Wemb_mask):
        _x = tf.nn.embedding_lookup(Wemb, x) # x: (batch_size, timesteps, code_size) code ids; the lookup adds an n_hidden dim
        _x_mask = tf.nn.embedding_lookup(Wemb_mask, x)
        # print (_x.get_shape())
        # print (_x_mask.get_shape())
        emb_vecs = tf.multiply(_x, _x_mask)
        emb_vecs = tf.reduce_sum(emb_vecs, 2)
        # print (emb_vecs.get_shape())
        return emb_vecs
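    # Shape note: embedding() looks up each code id in both the embedding table and the
    # padding mask, multiplies them so padded positions contribute zero vectors, and sums
    # over the code dimension: (batch, timesteps, code_size, n_hidden) -> (batch, timesteps, n_hidden).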


    ############################ Convolutional Neural Network ##############################
    def cnn_config(self, data_loader, init_std=0.05):
        # Network Parameters
        self.init_std = init_std
        self.n_hidden = 256 # hidden dimensions of embedding
        self.n_hidden_1 = 128
        self.n_hidden_2 = 128
        self.n_words = data_loader.n_words
        self.n_classes = FLAGS.n_classes
        self.n_filters = 128
        self.num_input = data_loader.dim_input
        self.timesteps = data_loader.timesteps
        self.code_size = data_loader.code_size
        self.dim_hidden = [self.n_hidden_1, self.n_hidden_2, FLAGS.n_classes] # for AD
        self.filter_sizes = [3, 4, 5]
        self.learner = self.cnn_sequential

    def build_conv_weights(self, weights):
        for i, filter_size in enumerate(self.filter_sizes):
            filter_shape = [filter_size, self.n_hidden, 1, self.n_filters]
            weights["conv_W"+str(filter_size)] = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="conv_W"+str(filter_size))
            weights["conv_b"+str(filter_size)] = tf.Variable(tf.constant(0.1, shape=[self.n_filters]), name="conv_b"+str(filter_size))
        return weights

    def conv(self, emb_vecs, weights, is_training=True):
        '''Create a convolution + maxpool layer for each filter size'''
        pooled_outputs = []
        emb_expanded = tf.expand_dims(emb_vecs, -1)
        # print(emb_expanded.get_shape())
        for i, filter_size in enumerate(self.filter_sizes):
            W = weights["conv_W"+str(filter_size)]
            b = weights["conv_b"+str(filter_size)]
            with tf.variable_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer
                conv_ = tf.nn.conv2d(
                    emb_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                # Apply nonlinearity
                h = tf.nn.leaky_relu(tf.nn.bias_add(conv_, b), name="relu")
                with tf.name_scope("bnorm{}".format(filter_size)) as scope:
                    h = layers.batch_norm(h, updates_collections=None,
                                          decay=0.99,
                                          scale=True, center=True,
                                          is_training=is_training, reuse=tf.AUTO_REUSE, scope=scope)
                # Maxpooling over the outputs
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, self.timesteps - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features
        num_filters_total = self.n_filters * len(self.filter_sizes)
        h_pool = tf.concat(pooled_outputs, 3)
        h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])
        return h_pool_flat
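    # Shape note: each filter size yields n_filters feature maps that are max-pooled
    # over time to a single value, so conv() returns a fixed-length vector of
    # n_filters * len(filter_sizes) features per example (128 * 3 = 384 by default).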

    def cnn_sequential(self, x, weights, dropout, reuse=False, is_training=True, type="source"):
        xemb = self.embedding(x, weights["emb_W"], weights["emb_mask_W"])

        # convolutional network
        hout = self.conv(xemb, weights, is_training)

        h_ = layers.dropout(hout, keep_prob=dropout)

        for i, dim in enumerate(self.dim_hidden[:-1]):
            h_ = self.fc(h_, weights["fc_W"+str(i)], weights["fc_b"+str(i)])
            h_ = tf.nn.dropout(h_, dropout)

        # keep the pre-logits representation, matching rnn_sequential and the
        # (logits, representation) interface expected in task_metalearn
        x_rep = tf.identity(h_)

        # Logits linear layer, i.e. softmax without normalization.
        N, Min = h_.get_shape()
        i = len(self.dim_hidden)-1
        logits = self.fc(h_, weights["fc_W"+str(i)], weights["fc_b"+str(i)], relu=False)
        return logits, x_rep

    ############################ Recurrent Neural Network ##############################
    def rnn_config(self, data_loader, init_std=0.05):
        # Network Parameters
        self.init_std = init_std
        self.n_hidden = 256 # hidden dimensions of embedding
        self.n_hidden_1 = 128
        self.n_hidden_2 = 128
        self.n_words = data_loader.n_words
        self.num_input = data_loader.dim_input
        self.n_classes = FLAGS.n_classes
        self.timesteps = data_loader.timesteps
        self.code_size = data_loader.code_size
        self.dim_hidden = [self.n_hidden_1, self.n_hidden_2, FLAGS.n_classes]
        self.learner = self.rnn_sequential

    def build_lstm_weights(self, weights):
        # # Keep W_xh and W_hh separate here as well to reuse initialization methods
        # with tf.variable_scope(scope or type(self).__name__):
        weights["lstm_W_xh"] = tf.get_variable('lstm_W_xh', [self.n_hidden, 4 * self.n_hidden],
                               initializer=self.orthogonal_initializer())
        weights["lstm_W_hh"] = tf.get_variable('lstm_W_hh', [self.n_hidden, 4 * self.n_hidden],
                               initializer=self.lstm_identity_initializer(0.95))
        weights["lstm_b"] = tf.get_variable('lstm_b', [4 * self.n_hidden])
        return weights

    def lstm_identity_initializer(self, scale):
        def _initializer(shape, dtype=tf.float32, partition_info=None):
            """Ugly because the LSTM params are calculated in one matrix multiply"""
            size = shape[0]
            t = np.zeros(shape)
            t[:, size:size * 2] = np.identity(size) * scale
            t[:, :size] = self.orthogonal([size, size])
            t[:, size * 2:size * 3] = self.orthogonal([size, size])
            t[:, size * 3:] = self.orthogonal([size, size])
            return tf.constant(t, dtype=dtype)
        return _initializer

    def orthogonal_initializer(self):
        def _initializer(shape, dtype=tf.float32, partition_info=None):
            return tf.constant(self.orthogonal(shape), dtype)
        return _initializer

    def orthogonal(self, shape):
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v
        return q.reshape(shape)
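    # Initialization note: orthogonal() keeps the orthonormal factor of the SVD of a
    # random Gaussian matrix. lstm_identity_initializer() fills the four gate blocks of
    # the concatenated LSTM weight matrix with orthogonal blocks plus one scaled identity
    # block, since all gates are computed in a single matmul (see LSTMCell below).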

    def rnn_sequential(self, x, weights, dropout, reuse=False, is_training=True, type='source'):
        # embedding
        xemb = self.embedding(x, weights["emb_W"], weights["emb_mask_W"])

        # recurrent neural networks
        xemb = tf.unstack(xemb, self.timesteps, 1)
        lstm_cell = LSTMCell(self.n_hidden, weights["lstm_W_xh"], weights["lstm_W_hh"], weights["lstm_b"])
        # initial (c, h) state
        if type == "source":
            W_state_c = tf.random_normal([(self.n_tasks-1)*FLAGS.update_batch_size, self.n_hidden], stddev=0.1)
            W_state_h = tf.random_normal([(self.n_tasks-1)*FLAGS.update_batch_size, self.n_hidden], stddev=0.1)
        elif type == "target":
            W_state_c = tf.random_normal([FLAGS.update_batch_size, self.n_hidden], stddev=0.1)
            W_state_h = tf.random_normal([FLAGS.update_batch_size, self.n_hidden], stddev=0.1)
        # outputs, state = tf.nn.dynamic_rnn(lstm_cell, xemb, initial_state=(W_state_c, W_state_h), dtype=tf.float32)
        outputs, state = tf.nn.static_rnn(lstm_cell, xemb, initial_state=(W_state_c, W_state_h), dtype=tf.float32)
        _, hout = state

        with tf.variable_scope("dropout"):
            h_ = layers.dropout(hout, keep_prob=dropout)

        for i, dim in enumerate(self.dim_hidden[:-1]):
            h_ = self.fc(h_, weights["fc_W"+str(i)], weights["fc_b"+str(i)])
            h_ = tf.nn.dropout(h_, dropout)

        x_rep = tf.identity(h_)

        # Logits linear layer, i.e. softmax without normalization.
        N, Min = h_.get_shape()
        i = len(self.dim_hidden)-1
        logits = self.fc(h_, weights["fc_W"+str(i)], weights["fc_b"+str(i)], relu=False)
        return logits, x_rep


    def build_graph(self):
        """Build the computational graph of the model."""
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Inputs.
            with tf.name_scope('inputs'):
                self.input_s = tf.placeholder(tf.int32, (FLAGS.meta_batch_size, (self.n_tasks-1) * FLAGS.update_batch_size, self.timesteps, self.code_size), 'source_x')
                self.input_t = tf.placeholder(tf.int32, (FLAGS.meta_batch_size, FLAGS.update_batch_size, self.timesteps, self.code_size), 'target_x')
                self.label_s = tf.placeholder(tf.int64, (FLAGS.meta_batch_size, (self.n_tasks-1) * FLAGS.update_batch_size), 'source_y')
                self.label_t = tf.placeholder(tf.int64, (FLAGS.meta_batch_size, FLAGS.update_batch_size), 'target_y')

                self.ph_training = tf.placeholder(tf.bool, name='trainingFlag')
                self.ph_dropout = tf.placeholder(tf.float32, (), 'dropout')

            # Model.
            # construct metatrain_ and metaval_
            if FLAGS.method == "cnn" or FLAGS.method == "rnn":
                self.build_model((self.input_s, self.input_t, self.label_s, self.label_t), prefix='metatrain_', is_training=self.ph_training)

            # Initialize variables, i.e. weights and biases.
            self.op_init = tf.global_variables_initializer()
            self.op_weights = self.get_op_variables()

            # Summaries for TensorBoard and Saver for model parameters.
            self.op_summary = tf.summary.merge_all()
            self.op_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
            print('graph built!')

        self.graph.finalize()

    def get_op_variables(self):
        if FLAGS.method == "cnn":
            op_weights = dict()
            op_var = tf.trainable_variables()
            # embedding
            op_weights["emb_W"] = [v for v in op_var if "emb_W" in v.name][0]
            # cnn
            for i, filter_size in enumerate(self.filter_sizes):
                op_weights["conv_W"+str(filter_size)] = [v for v in op_var if "conv_W"+str(filter_size) in v.name][0]
                op_weights["conv_b"+str(filter_size)] = [v for v in op_var if "conv_b"+str(filter_size) in v.name][0]
            # fully connected
            for i, dim in enumerate(self.dim_hidden):
                op_weights["fc_W"+str(i)] = [v for v in op_var if "fc_W"+str(i) in v.name][0]
                op_weights["fc_b"+str(i)] = [v for v in op_var if "fc_b"+str(i) in v.name][0]
        elif FLAGS.method == "rnn":
            op_weights = dict()
            op_var = tf.trainable_variables()
            # embedding
            op_weights["emb_W"] = [v for v in op_var if "emb_W" in v.name][0]
            # lstm
            op_weights["lstm_W_xh"] = [v for v in op_var if "lstm_W_xh" in v.name][0]
            op_weights["lstm_W_hh"] = [v for v in op_var if "lstm_W_hh" in v.name][0]
            op_weights["lstm_b"] = [v for v in op_var if "lstm_b" in v.name][0]
            # fully connected
            for i, dim in enumerate(self.dim_hidden):
                op_weights["fc_W"+str(i)] = [v for v in op_var if "fc_W"+str(i) in v.name][0]
                op_weights["fc_b"+str(i)] = [v for v in op_var if "fc_b"+str(i) in v.name][0]
        return op_weights

    def build_weights(self):
        weights = {}
        if FLAGS.method == "cnn":
            weights = self.build_emb_weights(weights)
            weights = self.build_conv_weights(weights)
            weights = self.build_fc_weights(self.n_filters * len(self.filter_sizes), weights)
        elif FLAGS.method == "rnn":
            weights = self.build_emb_weights(weights)
            weights = self.build_lstm_weights(weights)
            weights = self.build_fc_weights(self.n_hidden, weights)
        return weights

    def build_model(self, input_tensors, prefix='metatrain_', is_training=True):
        """
        Args:
            input_tensors:
                source_xb:   [batch_size, (n_tasks-1)*update_batch_size, data_shape]
                source_yb:   [batch_size, (n_tasks-1)*update_batch_size, ]
                target_xb:   [batch_size, update_batch_size, data_shape]
                target_yb:   [batch_size, update_batch_size, ] i.e., querysz = 1
            # update_batch_size: number of examples used for the inner gradient update (K examples per task)
            # meta_batch_size: number of meta-batches sampled per meta-update
            prefix:        pretrain_/metatrain_/metaval_/metatest_; for training, the train, val and test networks are built together.
        """
        # source: training data for inner gradient, target: test data for meta gradient
        source_xb, target_xb, source_yb, target_yb = input_tensors

        # create or reuse the network variables; batch_norm variables are not included here,
        # so an extra reuse mechanism is needed for them.
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as training_scope:
            # Define the weights. weights is a dictionary
            self.weights = weights = self.build_weights()

            num_updates = max(self.test_num_updates, FLAGS.num_updates)
            # target_preds_tasks[i] and target_losses_tasks[i] are the output and loss after i+1 gradient updates
            source_pred_tasks, source_loss_tasks, source_acc_tasks, source_auc_tasks = [], [], [], [] # source and target have separate losses and accuracies
            target_losses_tasks = [[]]*num_updates # results of every update step on the test data
            target_preds_tasks = [[]]*num_updates # prediction
            target_accs_tasks = [[]]*num_updates
            target_aucs_tasks = [[]]*num_updates

            def task_metalearn(inp, reuse=True):
                """
                Perform gradient descent for one task in the meta-batch.
                Args:
                    source_x:   [(n_tasks-1)*update_batch_size, data_shape]
                    source_y:   [(n_tasks-1)*update_batch_size, ]
                    target_x:   [update_batch_size, data_shape]
                    target_y:   [update_batch_size, ]
                    training:   training or not, for batch_norm
                """
                source_x, target_x, source_y, target_y = inp # map_fn passes a single argument, so unpack from the tuple
                # print (source_x.get_shape())
                # print (target_x.get_shape())
                # print (source_y.get_shape())
                # print (target_y.get_shape())

                # record the ops at each update step; each element holds the results of that step.
                target_preds, target_losses, target_accs, target_aucs, target_represents = [], [], [], [], []

                # the first call creates the variables, so reuse must be turned off
                source_pred, _ = self.learner(source_x, weights, self.ph_dropout, reuse=False, is_training=is_training, type="source")
                # print (source_pred.get_shape())
                source_loss = self.loss_func(source_pred, source_y)
                source_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(source_pred), 1), source_y)

                # compute gradients
                grads = tf.gradients(source_loss, list(weights.values()))

                if FLAGS.stop_grad: # if True, do not use second derivatives in meta-optimization (for speed)
                    grads = [tf.stop_gradient(grad) for grad in grads]

                # grad and variable dict
                gvs = dict(zip(weights.keys(), grads))
                # theta_pi = theta - alpha * grads
                fast_weights = dict(zip(weights.keys(), [weights[key] - tf.multiply(self.update_lr, gvs[key]) for key in weights.keys()]))
                # fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gvs[key] for key in weights.keys()]))

                # use theta_pi for fast adaptation
                target_pred, target_represent = self.learner(target_x, fast_weights, self.ph_dropout, reuse=True, is_training=is_training, type="target")
                target_loss = self.loss_func(target_pred, target_y)
                target_preds.append(target_pred)
                target_losses.append(target_loss)
                target_represents.append(target_represent)

                # continue to build the T1-TK steps of the graph
                for _ in range(1, num_updates): # e.g., num_updates = 4: three more updates after the first
                    # T_k loss on meta-train
                    # we need the meta-train loss to fine-tune the task and the meta-test loss to update theta
                    loss = self.loss_func(self.learner(source_x, fast_weights, self.ph_dropout, reuse=True, is_training=is_training, type="source")[0], source_y)
                    # compute gradients
                    grads = tf.gradients(loss, list(fast_weights.values()))

                    # compose grad and variable dict
                    gvs = dict(zip(fast_weights.keys(), grads))
                    # update theta_pi accordingly
                    fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - tf.multiply(self.update_lr, gvs[key])
                                          for key in fast_weights.keys()]))

                    # forward on theta_pi
                    target_pred, target_represent = self.learner(target_x, fast_weights, self.ph_dropout, reuse=True, is_training=is_training, type="target")
                    # accumulate all meta-test losses to update theta
                    target_loss = self.loss_func(target_pred, target_y)
                    target_preds.append(target_pred)
                    target_losses.append(target_loss)
                    target_represents.append(target_represent)


                task_output = [target_represents, source_pred, target_preds, source_loss, target_losses]
                for j in range(num_updates):
                    target_accs.append(tf.contrib.metrics.accuracy(predictions=tf.argmax(tf.nn.softmax(target_preds[j]), 1), labels=target_y))
                task_output.extend([source_acc, target_accs])
                return task_output
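            # task_metalearn implements the MAML-style inner loop for one meta-batch
            # element: adapt the shared weights on the source tasks with num_updates
            # gradient steps (theta_k = theta_{k-1} - update_lr * grad), then evaluate
            # the adapted weights on the target task; the collected target losses drive
            # the outer (meta) update below.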

            if FLAGS.norm != 'None': # batch norm or layer norm
                # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
                unused = task_metalearn((source_xb[0], target_xb[0], source_yb[0], target_yb[0]), False)

            out_dtype = [[tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates,
                         tf.float32, [tf.float32] * num_updates]

            result = tf.map_fn(task_metalearn, elems=(source_xb, target_xb, source_yb, target_yb),
                              dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size, name='map_fn')
            target_represents_tasks, source_pred_tasks, target_preds_tasks, source_loss_tasks, target_losses_tasks, \
                           source_acc_tasks, target_accs_tasks = result

        ## Performance & Optimization
        # average loss
        self.source_loss = source_loss = tf.reduce_sum(source_loss_tasks) / FLAGS.meta_batch_size
        # [avgloss_T1, avgloss_T2, ..., avgloss_TK]
        self.target_losses = target_losses = [tf.reduce_sum(target_losses_tasks[j]) / FLAGS.meta_batch_size
                                              for j in range(num_updates)]
        self.source_acc = source_acc = tf.reduce_sum(source_acc_tasks) / FLAGS.meta_batch_size
        self.target_accs = target_accs = [tf.reduce_sum(target_accs_tasks[j]) / FLAGS.meta_batch_size
                                            for j in range(num_updates)]
        self.source_pred = source_pred_tasks
        self.target_preds = target_preds_tasks[FLAGS.num_updates-1]
        self.target_represent = target_represents_tasks[FLAGS.num_updates-1]

        if self.ph_training is not False:
            # meta-train optim
            optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
            # meta-train gradients; target_losses[-1] is the accumulated loss across tasks.
            self.gvs = gvs = optimizer.compute_gradients(self.source_loss + self.target_losses[FLAGS.num_updates-1])
            # update theta
            self.metatrain_op = optimizer.apply_gradients(gvs)

        ## Summaries
        # NOTICE: every time the model is built, the source loss is added to the summary again, under a different prefix.
        tf.summary.scalar(prefix+'Pre-update loss', source_loss)
        tf.summary.scalar(prefix+'Pre-update accuracy', source_acc)
        for j in range(num_updates):
            tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), target_losses[j])
            tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), target_accs[j])

    def compute_metrics(self, predictions, labels):
        '''compute metrics scores'''
        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, predictions)
        auc = sklearn.metrics.auc(fpr, tpr)
        ncorrects = sum(predictions == labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        ap = sklearn.metrics.average_precision_score(labels, predictions, average='micro')
        f1score = sklearn.metrics.f1_score(labels, predictions, average='micro')
        return auc, ap, f1score


    # def evaluate(self, sample, label, sess=None, prefix="metaval_"):
    def evaluate(self, episode, data_tuple_val, sess=None, prefix="metaval_"):
        '''validate the meta-learning model'''
        target_acc, target_vals, target_preds = [], [], []
        size = len(episode)

        for begin in range(0, size, FLAGS.meta_batch_size):
            end = begin + FLAGS.meta_batch_size
            end = min([end, size])
            if end-begin < FLAGS.meta_batch_size: break

            batch_idx = range(begin, end)
            sample, label = self.get_feed_data(episode, batch_idx, data_tuple_val, is_training=False)

            X_tensor_s = self.convert_to_array(sample[:, :(self.n_tasks-1) * FLAGS.update_batch_size, :, :])
            X_tensor_t = self.convert_to_array(sample[:, (self.n_tasks-1) * FLAGS.update_batch_size:, :, :])
            y_tensor_s = self.convert_to_array(label[:, :(self.n_tasks-1) * FLAGS.update_batch_size])
            y_tensor_t = self.convert_to_array(label[:, (self.n_tasks-1) * FLAGS.update_batch_size:])

            feed_dict = {self.input_s: X_tensor_s, self.input_t: X_tensor_t, self.label_s: y_tensor_s, self.label_t: y_tensor_t, self.ph_dropout: 1, self.ph_training: False}
            input_tensors = [self.target_preds, self.target_accs[FLAGS.num_updates-1]]
            metaval_target_preds, metaval_target_accs = sess.run(input_tensors, feed_dict)
            target_acc.append(metaval_target_accs)
            target_preds.append(metaval_target_preds)
            target_vals.append(y_tensor_t)

        target_vals = np.array(target_vals).flatten()
        target_preds = np.array([np.argmax(preds, axis=2) for preds in target_preds]).flatten()

        target_acc = np.mean(target_acc)
        target_auc, target_ap, target_f1 = self.compute_metrics(target_preds, target_vals)

        return target_acc, target_auc, target_ap, target_f1


    def get_feed_data(self, episode, batch_idx, data_tuple, is_training, is_show=False):
        '''given batch indices, get data arrays from the generated index episodes'''
        n_samples_per_task = FLAGS.update_batch_size
        data_s, data_t, label_s, label_t = data_tuple
        # generate episode
        sample, label = [], []
        batch_count = 0
        for i in range(len(batch_idx)): # the 1st dimension is the batch size
            # i.e., sample 16 patients from selected tasks
            # len of spl and lbl: 4 * 16
            spl, lbl = [], [] # samples and labels in one episode
            bi = batch_idx[i]
            data_idx = episode[bi] # all tasks are merged: [task1, task2, ..., taskn], where taskn is the target
            n_source = 0
            for j in range(len(self.data_loader.source)):
                s_idx = data_idx[j*n_samples_per_task:(j+1)*n_samples_per_task]
                spl.extend(data_s[j][s_idx])
                lbl.extend(label_s[j][s_idx])
                n_source += n_samples_per_task
            ### do not keep pos/neg ratio
            if is_training:
                t_idx = data_idx[n_source:]
                spl.extend(data_t[0][t_idx])
                lbl.extend(label_t[0][t_idx])
            else:
                t_idx = data_idx[n_source:]
                spl.extend(data_t[t_idx])
                lbl.extend(label_t[t_idx])

            batch_count += 1
            # add meta_batch
            sample.append(spl)
            label.append(lbl)

        sample = np.array(sample, dtype="float32")
        label = np.array(label, dtype="float32")
        return sample, label


    def fit(self, episode, episode_val, ifold, exp_string, model_file=None):
        sess = tf.Session(graph=self.graph)
        # initialize variables first so that a restored checkpoint (below) is not overwritten
        sess.run(self.op_init)
        if FLAGS.resume or not FLAGS.train:
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
            if model_file:
                ind1 = model_file.index('model')
                print("Restoring model weights from " + model_file)
                self.op_saver.restore(sess, model_file)

        if FLAGS.log:
            train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)

        # load data for metatrain
        data_tuple = (self.data_loader.data_s, self.data_loader.data_t, self.data_loader.label_s, self.data_loader.label_t)
        # load data for metaeval
        data_tuple_val = (self.data_loader.data_s, self.data_loader.data_tt_val[ifold], self.data_loader.label_s, self.data_loader.label_tt_val[ifold])

        prelosses, postlosses, preaccs, postaccs = [], [], [], []

        # train for metatrain_iterations epochs
        indices = collections.deque()
        for itr in range(FLAGS.metatrain_iterations):
            feed_dict = {}
            input_tensors = [self.metatrain_op]

            if itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0:
                input_tensors.extend([self.op_summary, self.source_loss, self.target_losses[FLAGS.num_updates-1]])
                input_tensors.extend([self.source_acc, self.target_accs[FLAGS.num_updates-1], self.target_preds])

            if len(indices) < FLAGS.meta_batch_size:
                indices.extend(np.random.permutation(len(episode)))
            batch_idx = [indices.popleft() for i in range(FLAGS.meta_batch_size)]
            sample, label = self.get_feed_data(episode, batch_idx, data_tuple, is_training=True)

            X_tensor_s = self.convert_to_array(sample[:, :(self.n_tasks-1) * FLAGS.update_batch_size, :, :])
            X_tensor_t = self.convert_to_array(sample[:, (self.n_tasks-1) * FLAGS.update_batch_size:, :, :])
            y_tensor_s = self.convert_to_array(label[:, :(self.n_tasks-1) * FLAGS.update_batch_size])
            y_tensor_t = self.convert_to_array(label[:, (self.n_tasks-1) * FLAGS.update_batch_size:])
            feed_dict = {self.input_s: X_tensor_s, self.input_t: X_tensor_t, self.label_s: y_tensor_s, self.label_t: y_tensor_t, self.ph_dropout: FLAGS.dropout, self.ph_training: True}

            result = sess.run(input_tensors, feed_dict)
            if itr % SUMMARY_INTERVAL == 0:
                prelosses.append(result[-5])
                preaccs.append(result[-3])
                if FLAGS.log:
                    train_writer.add_summary(result[1], itr)
                postlosses.append(result[-4])
                postaccs.append(result[-2])
                postauc, postap, postf1 = self.compute_metrics(np.argmax(result[-1], axis=2).flatten(), y_tensor_t.flatten())

            if (itr!=0) and itr % PRINT_INTERVAL == 0:
                print_str = 'Iteration ' + str(itr)
                print_str += ': sacc: ' + str(np.mean(preaccs)) + ', tacc: ' + str(np.mean(postaccs))
                print_str += " tauc: " + str(postauc) + " tap: " + str(postap) + " tf1: " + str(postf1)
                print(print_str)
                preaccs, postaccs = [], []
                prelosses, postlosses = [], []

            if (itr!=0) and itr % SAVE_INTERVAL == 0:
                self.op_saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))

            if (itr!=0) and itr % TEST_PRINT_INTERVAL == 0:
                target_accs, target_aucs, target_ap, target_f1s = self.evaluate(episode_val, data_tuple_val, sess=sess, prefix="metaval_")
                self.auc_stable.append(target_aucs)
                self.f1s_stable.append(target_f1s)
                print('Validation results: ' + "tAcc: " + str(target_accs) + ", tAuc: " + str(target_aucs) + ", tAP: " + str(target_ap) + ", tF1: " + str(target_f1s))
                print("---------------")
        self.op_saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
        print("---------------")

        # store weights values for fine-tuning
        feed_dict = {}
        for k in self.op_weights:
            self.weights_for_finetune[k] = sess.run([self.op_weights[k]], feed_dict)[0]
        return sess


class LSTMCell(RNNCell):
    '''Vanilla LSTM implemented with the same initializations as BN-LSTM'''
    def __init__(self, num_units, W_xh, W_hh, bias):
        self.num_units = num_units
        self.W_xh = W_xh
        self.W_hh = W_hh
        self.bias = bias

    @property
    def state_size(self):
        return (self.num_units, self.num_units)

    @property
    def output_size(self):
        return self.num_units

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE):
            c, h = state

            # hidden = tf.matmul(x, W_xh) + tf.matmul(h, W_hh) + bias
            # improve speed by concat.
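            # Concatenating [x, h] and stacking [W_xh; W_hh] computes both products in
            # a single matmul: [x, h] @ [W_xh; W_hh] == x @ W_xh + h @ W_hh.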
            concat = tf.concat([x, h], 1)
            W_both = tf.concat([self.W_xh, self.W_hh], 0)
            hidden = tf.matmul(concat, W_both) + self.bias

            i, j, f, o = tf.split(hidden, 4, axis=1)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j)
            new_h = tf.tanh(new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h)