Diff of /class_DeepIMV_AISTATS.py [000000] .. [0f2bcf]

import tensorflow as tf
import numpy as np

from tensorflow.contrib.layers import fully_connected as FC_Net


_EPSILON = 1e-8


def div(x_, y_):
    # numerically stabilized division
    return tf.div(x_, y_ + _EPSILON)

def log(x_):
    # numerically stabilized log
    return tf.log(x_ + _EPSILON)

def xavier_initialization(size):
    dim_ = size[0]
    xavier_stddev = 1. / tf.sqrt(dim_ / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

### DEFINE PREDICTOR
def predictor(x_, o_dim_, o_type_, num_layers_=1, h_dim_=100, activation_fn=tf.nn.relu, keep_prob_=1.0, w_reg_=None):
    '''
        INPUT
            x_           : (2D-tensor) input
            o_dim_       : (int) output dimension
            o_type_      : (string) output type, one of {'continuous', 'categorical', 'binary'}
            num_layers_  : (int) # of hidden layers
            activation_fn: tf activation function

        OUTPUT
            output tensor; its activation depends on o_type_
    '''
    if o_type_ == 'continuous':
        out_fn = None
    elif o_type_ == 'categorical':
        out_fn = tf.nn.softmax  # for classification tasks
    elif o_type_ == 'binary':
        out_fn = tf.nn.sigmoid
    else:
        raise ValueError('Wrong output type: {}'.format(o_type_))

    if num_layers_ == 1:
        out = FC_Net(inputs=x_, num_outputs=o_dim_, activation_fn=out_fn, weights_regularizer=w_reg_, scope='out')
    else:  # num_layers_ > 1
        net = x_
        for tmp_layer in range(num_layers_ - 1):
            net = FC_Net(inputs=net, num_outputs=h_dim_, activation_fn=activation_fn, weights_regularizer=w_reg_, scope='layer_'+str(tmp_layer))
            net = tf.nn.dropout(net, keep_prob=keep_prob_)
        out = FC_Net(inputs=net, num_outputs=o_dim_, activation_fn=out_fn, weights_regularizer=w_reg_, scope='out')
    return out

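# --- Illustrative sketch (not in the original file): a minimal call to
# `predictor`, building a 2-layer binary head on a toy placeholder. The
# names `_example_predictor_head` and `toy_x` are hypothetical.
def _example_predictor_head():
    toy_x = tf.placeholder(tf.float32, [None, 10])
    with tf.variable_scope('toy_head'):
        # sigmoid output since o_type_='binary'
        return predictor(toy_x, o_dim_=1, o_type_='binary', num_layers_=2, h_dim_=32)
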
### DEFINE STOCHASTIC ENCODER
def stochastic_encoder(x_, o_dim_, num_layers_=1, h_dim_=100, activation_fn=tf.nn.relu, keep_prob_=1.0, w_reg_=None):
    '''
        INPUT
            x_           : (2D-tensor) input
            o_dim_       : (int) output dimension
            num_layers_  : (int) # of hidden layers
            activation_fn: tf activation function

        OUTPUT
            [mu, logvar] tensor, concatenated along the last axis
    '''
    if num_layers_ == 1:
        out = FC_Net(inputs=x_, num_outputs=o_dim_, activation_fn=None, weights_regularizer=w_reg_, scope='out')
    else:  # num_layers_ > 1
        net = x_
        for tmp_layer in range(num_layers_ - 1):
            net = FC_Net(inputs=net, num_outputs=h_dim_, activation_fn=activation_fn, weights_regularizer=w_reg_, scope='layer_'+str(tmp_layer))
            net = tf.nn.dropout(net, keep_prob=keep_prob_)
        out = FC_Net(inputs=net, num_outputs=o_dim_, activation_fn=None, weights_regularizer=w_reg_, scope='out')
    return out

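# --- Illustrative sketch (not in the original file): the encoder emits a
# 2*z_dim vector; callers split it into a mean and a log-variance and sample
# with the reparameterization trick, much as `_build_net` below does via
# ds.Normal(...).sample(). Names here are hypothetical.
def _example_encode_and_sample(x_, z_dim_=16):
    with tf.variable_scope('toy_encoder'):
        h = stochastic_encoder(x_, o_dim_=2*z_dim_, num_layers_=2, h_dim_=64)
    mu, logvar = h[:, :z_dim_], h[:, z_dim_:]
    eps = tf.random_normal(tf.shape(mu))
    return mu + tf.sqrt(tf.exp(logvar)) * eps  # z ~ N(mu, exp(logvar))
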
### DEFINE SUPERVISED LOSS FUNCTION
def loss_y(y_true_, y_pred_, y_type_):
    if y_type_ == 'continuous':
        tmp_loss = tf.reduce_sum((y_true_ - y_pred_)**2, axis=-1)
    elif y_type_ == 'categorical':
        tmp_loss = - tf.reduce_sum(y_true_ * log(y_pred_), axis=-1)
    elif y_type_ == 'binary':
        tmp_loss = - tf.reduce_sum(y_true_ * log(y_pred_) + (1.-y_true_) * log(1.-y_pred_), axis=-1)
    else:
        raise ValueError('Wrong output type: {}'.format(y_type_))
    return tmp_loss

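# --- Illustrative check (not in the original file): for y_type_='binary',
# loss_y is the usual cross-entropy. With y_true=1 and y_pred=0.9 it
# evaluates to -log(0.9) ~= 0.105 (up to the _EPSILON shift).
def _example_loss_y():
    y_true = tf.constant([[1.0]])
    y_pred = tf.constant([[0.9]])
    return loss_y(y_true, y_pred, 'binary')  # ~0.105
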
### DEFINE NETWORK-RELATED FUNCTIONS
def product_of_experts(mask_, mu_set_, logvar_set_):
    # fuse the observed views' Gaussian posteriors via a product of experts;
    # the initial 1. is the unit precision of the implicit N(0,1) prior expert
    tmp = 1.
    for m in range(len(mu_set_)):
        tmp += tf.reshape(mask_[:, m], [-1,1])*div(1., tf.exp(logvar_set_[m]))
    poe_var = div(1., tmp)
    poe_logvar = log(poe_var)

    # precision-weighted mean of the observed views
    tmp = 0.
    for m in range(len(mu_set_)):
        tmp += tf.reshape(mask_[:, m], [-1,1])*div(1., tf.exp(logvar_set_[m]))*mu_set_[m]
    poe_mu = poe_var * tmp

    return poe_mu, poe_logvar

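# --- Illustrative check (not in the original file): with an N(0,1) prior
# expert and per-view Gaussians N(mu_m, var_m), the fusion above computes
#   poe_var = 1 / (1 + sum_m mask_m / var_m)
#   poe_mu  = poe_var * sum_m mask_m * mu_m / var_m
# A NumPy mirror of the same computation, with hypothetical names:
def _example_poe_numpy(mu_set_, logvar_set_, mask_):
    precision   = 1.0   # the implicit standard-normal prior expert
    weighted_mu = 0.0
    for m in range(len(mu_set_)):
        w            = mask_[:, m:m+1] / np.exp(logvar_set_[m])
        precision   += w
        weighted_mu += w * mu_set_[m]
    poe_var = 1.0 / precision
    return poe_var * weighted_mu, np.log(poe_var)  # (poe_mu, poe_logvar)
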
###########################################################################
#### DEFINE PROPOSED-NETWORK
class DeepIMV_AISTATS:
    '''
        - Add mixture mode
        - Remove common/shared parts -- go back to the previous version
        - Leave the consistency loss; but make sure to set gamma = 0
    '''

    def __init__(self, sess, name, input_dims, network_settings):
        self.sess             = sess
        self.name             = name

        # INPUT/OUTPUT DIMENSIONS
        self.M                = len(input_dims['x_dim_set'])

        self.x_dim_set = {}
        for m in range(self.M):
            self.x_dim_set[m] = input_dims['x_dim_set'][m]

        self.y_dim            = input_dims['y_dim']
        self.y_type           = input_dims['y_type']

        self.z_dim            = input_dims['z_dim']  # z_dim is equivalent to W and Z
        self.steps_per_batch  = input_dims['steps_per_batch']

        # PREDICTOR INFO (VIEW-SPECIFIC)
        self.h_dim_p1         = network_settings['h_dim_p1']      # predictor hidden nodes
        self.num_layers_p1    = network_settings['num_layers_p1'] # predictor layers

        # PREDICTOR INFO (MULTI-VIEW)
        self.h_dim_p2         = network_settings['h_dim_p2']      # predictor hidden nodes
        self.num_layers_p2    = network_settings['num_layers_p2'] # predictor layers

        # ENCODER INFO
        self.h_dim_e          = network_settings['h_dim_e']      # encoder hidden nodes
        self.num_layers_e     = network_settings['num_layers_e'] # encoder layers

        self.fc_activate_fn   = network_settings['fc_activate_fn']
        self.reg_scale        = network_settings['reg_scale']    # regularization

        self._build_net()

    def _build_net(self):
        ds = tf.contrib.distributions

#         with tf.name_scope(self.name):
        with tf.variable_scope(self.name):
            self.mb_size        = tf.placeholder(tf.int32, [], name='batch_size')
            self.lr_rate        = tf.placeholder(tf.float32, name='learning_rate')
            self.k_prob         = tf.placeholder(tf.float32, name='keep_probability')

            ### INPUT/OUTPUT
            self.x_set          = {}
            for m in range(self.M):
                self.x_set[m]   = tf.placeholder(tf.float32, [None, self.x_dim_set[m]], 'input_{}'.format(m))

            self.mask           = tf.placeholder(tf.float32, [None, self.M], name='mask')
            self.y              = tf.placeholder(tf.float32, [None, self.y_dim],  name='output')

            ### BALANCING COEFFICIENTS
            self.alpha          = tf.placeholder(tf.float32, name='coef_alpha') # coefficient on the marginal IB loss (see LOSS_TOTAL)
            self.beta           = tf.placeholder(tf.float32, name='coef_beta')  # information bottleneck

            if self.reg_scale == 0:
                w_reg           = None
            else:
                w_reg           = tf.contrib.layers.l1_regularizer(scale=self.reg_scale)

            ### PRIOR
            prior_z  = ds.Normal(0.0, 1.0) # PoE prior - q(z)
            prior_z_set = {}
            for m in range(self.M):
                prior_z_set[m] = ds.Normal(0.0, 1.0) # view-specific prior - q(z_{m})

            ### STOCHASTIC ENCODER
            self.h_set        = {}
            self.mu_z_set     = {}
            self.logvar_z_set = {}

            for m in range(self.M):
                with tf.variable_scope('encoder{}'.format(m+1)):
                    self.h_set[m]      = stochastic_encoder(
                        x_=self.x_set[m], o_dim_=2*self.z_dim,
                        num_layers_=self.num_layers_e, h_dim_=self.h_dim_e,
                        activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                    )
                    self.mu_z_set[m]     = self.h_set[m][:, :self.z_dim]
                    self.logvar_z_set[m] = self.h_set[m][:, self.z_dim:]

            self.mu_z, self.logvar_z = product_of_experts(self.mask, self.mu_z_set, self.logvar_z_set)

            qz         = ds.Normal(self.mu_z, tf.sqrt(tf.exp(self.logvar_z)))
            self.z     = qz.sample()
            self.zs    = qz.sample(10)

            qz_set     = {}
            self.z_set = {}
            for m in range(self.M):
                qz_set[m]      = ds.Normal(self.mu_z_set[m], tf.sqrt(tf.exp(self.logvar_z_set[m])))
                self.z_set[m]  = qz_set[m].sample()

            ### PREDICTOR (JOINT)
            with tf.variable_scope('predictor'):
                self.y_hat = predictor(
                    x_=self.z, o_dim_=self.y_dim, o_type_=self.y_type,
                    num_layers_=self.num_layers_p2, h_dim_=self.h_dim_p2,
                    activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                )

            # this generates multiple samples of y (based on multiple samples drawn from the variational encoder)
            with tf.variable_scope('predictor', reuse=True):
                self.y_hats = predictor(
                    x_=self.zs, o_dim_=self.y_dim, o_type_=self.y_type,
                    num_layers_=self.num_layers_p2, h_dim_=self.h_dim_p2,
                    activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                )

            ### PREDICTOR (VIEW-SPECIFIC)
            self.y_hat_set = {}
            for m in range(self.M):
                with tf.variable_scope('predictor_set{}'.format(m)):
                    self.y_hat_set[m] = predictor(
                        x_=self.z_set[m], o_dim_=self.y_dim, o_type_=self.y_type,
                        num_layers_=self.num_layers_p1, h_dim_=self.h_dim_p1,
                        activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                    )

            ### OPTIMIZER
            global_vars      = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            enc_vars         = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name + '/encoder')
            pred_vars        = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name + '/predictor')

            ### CONSISTENCY LOSS
            self.LOSS_CONSISTENCY = 0.
            for m in range(self.M):
                self.LOSS_CONSISTENCY += 1./self.M * div(
                    tf.reduce_sum(self.mask[:, m] * tf.reduce_sum(ds.kl_divergence(qz, qz_set[m]), axis=-1)),
                    tf.reduce_sum(self.mask[:, m])
                )

            self.LOSS_KL       = tf.reduce_mean(
                tf.reduce_sum(ds.kl_divergence(qz, prior_z), axis=-1)
            )
            self.LOSS_P        = tf.reduce_mean(loss_y(self.y, self.y_hat, self.y_type))

            self.LOSS_IB_JOINT = self.LOSS_P + self.beta*self.LOSS_KL

            self.LOSS_Ps_all  = []
            self.LOSS_KLs_all = []
            for m in range(self.M):
                tmp_p              = loss_y(self.y, self.y_hat_set[m], self.y_type)
                tmp_kl             = tf.reduce_sum(ds.kl_divergence(qz_set[m], prior_z_set[m]), axis=-1)

                self.LOSS_Ps_all  += [div(tf.reduce_sum(self.mask[:,m]*tmp_p), tf.reduce_sum(self.mask[:,m]))]
                self.LOSS_KLs_all += [div(tf.reduce_sum(self.mask[:,m]*tmp_kl), tf.reduce_sum(self.mask[:,m]))]

            self.LOSS_Ps_all  = tf.stack(self.LOSS_Ps_all, axis=0)
            self.LOSS_KLs_all = tf.stack(self.LOSS_KLs_all, axis=0)

            self.LOSS_Ps      = tf.reduce_sum(self.LOSS_Ps_all)
            self.LOSS_KLs     = tf.reduce_sum(self.LOSS_KLs_all)

            self.LOSS_IB_MARGINAL = self.LOSS_Ps + self.beta*self.LOSS_KLs

            self.LOSS_TOTAL       = self.LOSS_IB_JOINT\
                                    + self.alpha*self.LOSS_IB_MARGINAL\
                                    + tf.losses.get_regularization_loss()

            self.global_step      = tf.contrib.framework.get_or_create_global_step()
            self.lr_rate_decayed  = tf.train.exponential_decay(self.lr_rate, self.global_step,
                                                               decay_steps=2*self.steps_per_batch,
                                                               decay_rate=0.97, staircase=True)

            opt       = tf.train.AdamOptimizer(self.lr_rate_decayed, beta1=0.5)

            ma        = tf.train.ExponentialMovingAverage(0.999, zero_debias=True)
            ma_update = ma.apply(tf.model_variables())

            self.solver = tf.contrib.training.create_train_op(self.LOSS_TOTAL, opt,
                                                              self.global_step,
                                                              update_ops=[ma_update])

    def train(self, x_set_, y_, m_, alpha_, beta_, lr_train, k_prob=1.0):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.y: y_, self.mask: m_,
                           self.alpha: alpha_, self.beta: beta_,
                           self.mb_size: np.shape(x_set_[0])[0],
                           self.lr_rate: lr_train, self.k_prob: k_prob})
        return self.sess.run([self.solver, self.LOSS_TOTAL, self.LOSS_P, self.LOSS_KL, self.LOSS_Ps,
                              self.LOSS_KLs, self.LOSS_CONSISTENCY],
                             feed_dict=feed_dict_)

    def get_loss(self, x_set_, y_, m_, alpha_, beta_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.y: y_, self.mask: m_,
                           self.alpha: alpha_, self.beta: beta_,
                           self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.LOSS_TOTAL, self.LOSS_P, self.LOSS_KL, self.LOSS_Ps,
                              self.LOSS_KLs, self.LOSS_CONSISTENCY, self.LOSS_Ps_all, self.LOSS_KLs_all],
                             feed_dict=feed_dict_)

    def predict_y(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run(self.y_hat, feed_dict=feed_dict_)

    def predict_ys(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.y_hat, self.y_hats], feed_dict=feed_dict_)

    def predict_yhat_set(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run(self.y_hat_set, feed_dict=feed_dict_)

    def predict_mu_z_and_mu_z_set(self, x_set_, m_):  # returns mu_z and mu_z_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.mu_z, self.mu_z_set], feed_dict=feed_dict_)

    def predict_logvar_z_and_logvar_z_set(self, x_set_, m_):  # returns logvar_z and logvar_z_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.logvar_z, self.logvar_z_set], feed_dict=feed_dict_)

    def predict_z_n_z_set(self, x_set_, m_):  # returns z and z_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.z, self.z_set], feed_dict=feed_dict_)

    def make_feed_dict(self, x_set_):
        feed_dict_ = {}
        for m in range(len(self.x_set)):
            feed_dict_[self.x_set[m]] = x_set_[m]
        return feed_dict_
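
# --- Illustrative usage sketch (not in the original file): how the class is
# typically instantiated and stepped once on toy data. All dimensions and
# settings below are hypothetical placeholders, not values from the paper.
def _example_build_and_train_step():
    input_dims = {'x_dim_set': [100, 50], 'y_dim': 2, 'y_type': 'categorical',
                  'z_dim': 16, 'steps_per_batch': 100}
    network_settings = {'h_dim_p1': 100, 'num_layers_p1': 2,
                        'h_dim_p2': 100, 'num_layers_p2': 2,
                        'h_dim_e': 100, 'num_layers_e': 2,
                        'fc_activate_fn': tf.nn.relu, 'reg_scale': 0.}
    sess  = tf.Session()
    model = DeepIMV_AISTATS(sess, 'deepimv', input_dims, network_settings)
    sess.run(tf.global_variables_initializer())

    x_set = {0: np.random.randn(8, 100), 1: np.random.randn(8, 50)}
    y     = np.eye(2)[np.random.randint(0, 2, 8)]        # one-hot labels
    mask  = np.ones((8, 2), dtype=np.float32)            # both views observed
    return model.train(x_set, y, mask, alpha_=1.0, beta_=0.01,
                       lr_train=1e-4, k_prob=0.7)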