class_DeepIMV_AISTATS.py
import tensorflow as tf
import numpy as np

from tensorflow.contrib.layers import fully_connected as FC_Net


_EPSILON = 1e-8


def div(x_, y_):
    return tf.div(x_, y_ + _EPSILON)


def log(x_):
    return tf.log(x_ + _EPSILON)


def xavier_initialization(size):
    dim_ = size[0]
    xavier_stddev = 1. / tf.sqrt(dim_ / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)


### DEFINE PREDICTOR
def predictor(x_, o_dim_, o_type_, num_layers_=1, h_dim_=100, activation_fn=tf.nn.relu, keep_prob_=1.0, w_reg_=None):
    '''
    INPUT
        x_            : (2D-tensor) input
        o_dim_        : (int) output dimension
        o_type_       : (string) output type; one of {'continuous', 'categorical', 'binary'}
        num_layers_   : (int) number of hidden layers
        h_dim_        : (int) number of nodes per hidden layer
        activation_fn : tf activation function
        keep_prob_    : (float) dropout keep probability

    OUTPUT
        o_type_ tensor
    '''
    if o_type_ == 'continuous':
        out_fn = None
    elif o_type_ == 'categorical':
        out_fn = tf.nn.softmax  # for classification tasks
    elif o_type_ == 'binary':
        out_fn = tf.nn.sigmoid
    else:
        raise ValueError('Unknown output type: {}'.format(o_type_))

    if num_layers_ == 1:
        out = FC_Net(inputs=x_, num_outputs=o_dim_, activation_fn=out_fn, weights_regularizer=w_reg_, scope='out')
    else:  # num_layers_ > 1
        for tmp_layer in range(num_layers_ - 1):
            if tmp_layer == 0:
                net = x_
            net = FC_Net(inputs=net, num_outputs=h_dim_, activation_fn=activation_fn, weights_regularizer=w_reg_, scope='layer_' + str(tmp_layer))
            net = tf.nn.dropout(net, keep_prob=keep_prob_)
        out = FC_Net(inputs=net, num_outputs=o_dim_, activation_fn=out_fn, weights_regularizer=w_reg_, scope='out')
    return out
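
# Usage sketch (illustrative; `features` is an assumed [batch, d] float tensor,
# not defined in this file): a 2-layer categorical head over 3 classes.
#
#     probs = predictor(features, o_dim_=3, o_type_='categorical',
#                       num_layers_=2, h_dim_=64, keep_prob_=0.8)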
|
|
### DEFINE STOCHASTIC ENCODER
def stochastic_encoder(x_, o_dim_, num_layers_=1, h_dim_=100, activation_fn=tf.nn.relu, keep_prob_=1.0, w_reg_=None):
    '''
    INPUT
        x_            : (2D-tensor) input
        o_dim_        : (int) output dimension
        num_layers_   : (int) number of hidden layers
        activation_fn : tf activation function

    OUTPUT
        [mu, logvar] tensor (concatenated along the last axis)
    '''
    if num_layers_ == 1:
        out = FC_Net(inputs=x_, num_outputs=o_dim_, activation_fn=None, weights_regularizer=w_reg_, scope='out')
    else:  # num_layers_ > 1
        for tmp_layer in range(num_layers_ - 1):
            if tmp_layer == 0:
                net = x_
            net = FC_Net(inputs=net, num_outputs=h_dim_, activation_fn=activation_fn, weights_regularizer=w_reg_, scope='layer_' + str(tmp_layer))
            net = tf.nn.dropout(net, keep_prob=keep_prob_)
        out = FC_Net(inputs=net, num_outputs=o_dim_, activation_fn=None, weights_regularizer=w_reg_, scope='out')
    return out
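
# Note: downstream code (see _build_net below) splits the encoder output into a
# mean and a log-variance, e.g. for o_dim_ = 2*z_dim:
#
#     h = stochastic_encoder(x, o_dim_=2*z_dim, num_layers_=2, h_dim_=100)
#     mu, logvar = h[:, :z_dim], h[:, z_dim:]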
|
|
### DEFINE SUPERVISED LOSS FUNCTION
def loss_y(y_true_, y_pred_, y_type_):
    if y_type_ == 'continuous':
        tmp_loss = tf.reduce_sum((y_true_ - y_pred_)**2, axis=-1)
    elif y_type_ == 'categorical':
        tmp_loss = - tf.reduce_sum(y_true_ * log(y_pred_), axis=-1)
    elif y_type_ == 'binary':
        tmp_loss = - tf.reduce_sum(y_true_ * log(y_pred_) + (1. - y_true_) * log(1. - y_pred_), axis=-1)
    else:
        raise ValueError('Unknown output type: {}'.format(y_type_))
    return tmp_loss
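
# loss_y returns a per-example loss of shape [batch]; the caller is responsible
# for reducing it (tf.reduce_mean for the joint predictor, a mask-weighted
# average for the view-specific predictors below).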
|
|
### DEFINE NETWORK-RELATED FUNCTIONS
def product_of_experts(mask_, mu_set_, logvar_set_):
    tmp = 1.
    for m in range(len(mu_set_)):
        tmp += tf.reshape(mask_[:, m], [-1, 1]) * div(1., tf.exp(logvar_set_[m]))
    poe_var = div(1., tmp)
    poe_logvar = log(poe_var)

    tmp = 0.
    for m in range(len(mu_set_)):
        tmp += tf.reshape(mask_[:, m], [-1, 1]) * div(1., tf.exp(logvar_set_[m])) * mu_set_[m]
    poe_mu = poe_var * tmp

    return poe_mu, poe_logvar
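
# Product-of-experts sketch: with a standard-normal prior expert and per-view
# experts q(z | x_m) = N(mu_m, sigma_m^2), the combined Gaussian has
#
#     1 / sigma^2 = 1 + sum_m mask_m / sigma_m^2      (hence `tmp` starts at 1.)
#     mu          = sigma^2 * sum_m mask_m * mu_m / sigma_m^2
#
# which is what product_of_experts() computes, returning (mu, log sigma^2).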
|
|
###########################################################################
#### DEFINE PROPOSED-NETWORK
class DeepIMV_AISTATS:
    '''
        - Add mixture mode
        - Remove common/shared parts -- go back to the previous version
        - Leave the consistency loss; but make sure to set gamma = 0
    '''

    def __init__(self, sess, name, input_dims, network_settings):
        self.sess = sess
        self.name = name

        # INPUT/OUTPUT DIMENSIONS
        self.M = len(input_dims['x_dim_set'])

        self.x_dim_set = {}
        for m in range(self.M):
            self.x_dim_set[m] = input_dims['x_dim_set'][m]

        self.y_dim = input_dims['y_dim']
        self.y_type = input_dims['y_type']

        self.z_dim = input_dims['z_dim']  # z_dim is equivalent to W and Z
        self.steps_per_batch = input_dims['steps_per_batch']

        # PREDICTOR INFO (VIEW-SPECIFIC)
        self.h_dim_p1 = network_settings['h_dim_p1']             # predictor hidden nodes
        self.num_layers_p1 = network_settings['num_layers_p1']   # predictor layers

        # PREDICTOR INFO (MULTI-VIEW)
        self.h_dim_p2 = network_settings['h_dim_p2']             # predictor hidden nodes
        self.num_layers_p2 = network_settings['num_layers_p2']   # predictor layers

        # ENCODER INFO
        self.h_dim_e = network_settings['h_dim_e']               # encoder hidden nodes
        self.num_layers_e = network_settings['num_layers_e']     # encoder layers

        self.fc_activate_fn = network_settings['fc_activate_fn']
        self.reg_scale = network_settings['reg_scale']           # regularization

        self._build_net()

    def _build_net(self):
        ds = tf.contrib.distributions

        # with tf.name_scope(self.name):
        with tf.variable_scope(self.name):
            self.mb_size = tf.placeholder(tf.int32, [], name='batch_size')
            self.lr_rate = tf.placeholder(tf.float32, name='learning_rate')
            self.k_prob = tf.placeholder(tf.float32, name='keep_probability')

            ### INPUT/OUTPUT
            self.x_set = {}
            for m in range(self.M):
                self.x_set[m] = tf.placeholder(tf.float32, [None, self.x_dim_set[m]], 'input_{}'.format(m))

            self.mask = tf.placeholder(tf.float32, [None, self.M], name='mask')
            self.y = tf.placeholder(tf.float32, [None, self.y_dim], name='output')

            ### BALANCING COEFFICIENTS
            self.alpha = tf.placeholder(tf.float32, name='coef_alpha')  # Consistency Loss
            self.beta = tf.placeholder(tf.float32, name='coef_beta')    # Information Bottleneck

            if self.reg_scale == 0:
                w_reg = None
            else:
                w_reg = tf.contrib.layers.l1_regularizer(scale=self.reg_scale)

            ### PRIOR
            prior_z = ds.Normal(0.0, 1.0)  # PoE Prior - q(z)
            prior_z_set = {}
            for m in range(self.M):
                prior_z_set[m] = ds.Normal(0.0, 1.0)  # View-Specific Prior - q(z_{m})

            ### STOCHASTIC ENCODER
            self.h_set = {}

            self.mu_z_set = {}
            self.logvar_z_set = {}

            for m in range(self.M):
                with tf.variable_scope('encoder{}'.format(m+1)):
                    self.h_set[m] = stochastic_encoder(
                        x_=self.x_set[m], o_dim_=2*self.z_dim,
                        num_layers_=self.num_layers_e, h_dim_=self.h_dim_e,
                        activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                    )
                    self.mu_z_set[m] = self.h_set[m][:, :self.z_dim]
                    self.logvar_z_set[m] = self.h_set[m][:, self.z_dim:]

            self.mu_z, self.logvar_z = product_of_experts(self.mask, self.mu_z_set, self.logvar_z_set)

            qz = ds.Normal(self.mu_z, tf.sqrt(tf.exp(self.logvar_z)))
            self.z = qz.sample()
            self.zs = qz.sample(10)

            qz_set = {}
            self.z_set = {}
            for m in range(self.M):
                qz_set[m] = ds.Normal(self.mu_z_set[m], tf.sqrt(tf.exp(self.logvar_z_set[m])))
                self.z_set[m] = qz_set[m].sample()
|
|
            ### PREDICTOR (JOINT)
            with tf.variable_scope('predictor'):
                self.y_hat = predictor(
                    x_=self.z, o_dim_=self.y_dim, o_type_=self.y_type,
                    num_layers_=self.num_layers_p2, h_dim_=self.h_dim_p2,
                    activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                )

            # this will generate multiple samples of y (based on multiple samples drawn from the variational encoder)
            with tf.variable_scope('predictor', reuse=True):
                self.y_hats = predictor(
                    x_=self.zs, o_dim_=self.y_dim, o_type_=self.y_type,
                    num_layers_=self.num_layers_p2, h_dim_=self.h_dim_p2,
                    activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                )

            ### PREDICTOR (VIEW-SPECIFIC)
            self.y_hat_set = {}
            for m in range(self.M):
                with tf.variable_scope('predictor_set{}'.format(m)):
                    self.y_hat_set[m] = predictor(
                        x_=self.z_set[m], o_dim_=self.y_dim, o_type_=self.y_type,
                        num_layers_=self.num_layers_p1, h_dim_=self.h_dim_p1,
                        activation_fn=self.fc_activate_fn, keep_prob_=self.k_prob, w_reg_=w_reg
                    )
|
|
            ### OPTIMIZER
            global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            enc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name + '/encoder')
            pred_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name + '/predictor')

            ### CONSISTENCY LOSS
            self.LOSS_CONSISTENCY = 0.
            for m in range(self.M):
                self.LOSS_CONSISTENCY += 1./self.M * div(
                    tf.reduce_sum(self.mask[:, m] * tf.reduce_sum(ds.kl_divergence(qz, qz_set[m]), axis=-1)),
                    tf.reduce_sum(self.mask[:, m])
                )

            self.LOSS_KL = tf.reduce_mean(
                tf.reduce_sum(ds.kl_divergence(qz, prior_z), axis=-1)
            )
            self.LOSS_P = tf.reduce_mean(loss_y(self.y, self.y_hat, self.y_type))

            self.LOSS_IB_JOINT = self.LOSS_P + self.beta*self.LOSS_KL

            self.LOSS_Ps_all = []
            self.LOSS_KLs_all = []
            for m in range(self.M):
                tmp_p = loss_y(self.y, self.y_hat_set[m], self.y_type)
                tmp_kl = tf.reduce_sum(ds.kl_divergence(qz_set[m], prior_z_set[m]), axis=-1)

                self.LOSS_Ps_all += [div(tf.reduce_sum(self.mask[:, m]*tmp_p), tf.reduce_sum(self.mask[:, m]))]
                self.LOSS_KLs_all += [div(tf.reduce_sum(self.mask[:, m]*tmp_kl), tf.reduce_sum(self.mask[:, m]))]

            self.LOSS_Ps_all = tf.stack(self.LOSS_Ps_all, axis=0)
            self.LOSS_KLs_all = tf.stack(self.LOSS_KLs_all, axis=0)

            self.LOSS_Ps = tf.reduce_sum(self.LOSS_Ps_all)
            self.LOSS_KLs = tf.reduce_sum(self.LOSS_KLs_all)

            self.LOSS_IB_MARGINAL = self.LOSS_Ps + self.beta*self.LOSS_KLs

            self.LOSS_TOTAL = self.LOSS_IB_JOINT\
                              + self.alpha*(self.LOSS_IB_MARGINAL)\
                              + tf.losses.get_regularization_loss()
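
            # Note: LOSS_CONSISTENCY is computed and returned for monitoring but is not
            # added to LOSS_TOTAL here, consistent with "set gamma = 0" in the class docstring.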
|
|

            self.global_step = tf.contrib.framework.get_or_create_global_step()
            self.lr_rate_decayed = tf.train.exponential_decay(self.lr_rate, self.global_step,
                                                              decay_steps=2*self.steps_per_batch,
                                                              decay_rate=0.97, staircase=True)

            opt = tf.train.AdamOptimizer(self.lr_rate_decayed, 0.5)

            ma = tf.train.ExponentialMovingAverage(0.999, zero_debias=True)
            ma_update = ma.apply(tf.model_variables())

            self.solver = tf.contrib.training.create_train_op(self.LOSS_TOTAL, opt,
                                                              self.global_step,
                                                              update_ops=[ma_update])

    def train(self, x_set_, y_, m_, alpha_, beta_, lr_train, k_prob=1.0):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.y: y_, self.mask: m_,
                           self.alpha: alpha_, self.beta: beta_,
                           self.mb_size: np.shape(x_set_[0])[0],
                           self.lr_rate: lr_train, self.k_prob: k_prob})
        return self.sess.run([self.solver, self.LOSS_TOTAL, self.LOSS_P, self.LOSS_KL, self.LOSS_Ps,
                              self.LOSS_KLs, self.LOSS_CONSISTENCY],
                             feed_dict=feed_dict_)

    def get_loss(self, x_set_, y_, m_, alpha_, beta_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.y: y_, self.mask: m_,
                           self.alpha: alpha_, self.beta: beta_,
                           self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.LOSS_TOTAL, self.LOSS_P, self.LOSS_KL, self.LOSS_Ps,
                              self.LOSS_KLs, self.LOSS_CONSISTENCY, self.LOSS_Ps_all, self.LOSS_KLs_all],
                             feed_dict=feed_dict_)

    def predict_y(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run(self.y_hat, feed_dict=feed_dict_)

    def predict_ys(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.y_hat, self.y_hats], feed_dict=feed_dict_)

    def predict_yhat_set(self, x_set_, m_):
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run(self.y_hat_set, feed_dict=feed_dict_)

    def predict_mu_z_and_mu_z_set(self, x_set_, m_):  # this outputs mu and mu_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.mu_z, self.mu_z_set], feed_dict=feed_dict_)

    def predict_logvar_z_and_logvar_z_set(self, x_set_, m_):  # this outputs logvar and logvar_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.logvar_z, self.logvar_z_set], feed_dict=feed_dict_)

    def predict_z_n_z_set(self, x_set_, m_):  # this outputs z and z_set
        feed_dict_ = self.make_feed_dict(x_set_)
        feed_dict_.update({self.mask: m_, self.mb_size: np.shape(x_set_[0])[0], self.k_prob: 1.0})
        return self.sess.run([self.z, self.z_set], feed_dict=feed_dict_)

    def make_feed_dict(self, x_set_):
        feed_dict_ = {}
        for m in range(len(self.x_set)):
            feed_dict_[self.x_set[m]] = x_set_[m]
        return feed_dict_
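

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original training
# pipeline). The view dimensions, hyper-parameters, and random data below are
# assumptions chosen purely to show how the class is wired together under
# TensorFlow 1.x.
if __name__ == '__main__':
    input_dims = {'x_dim_set': [10, 20],      # two views (assumed feature sizes)
                  'y_dim': 2,
                  'y_type': 'categorical',
                  'z_dim': 16,
                  'steps_per_batch': 100}
    network_settings = {'h_dim_p1': 50, 'num_layers_p1': 2,   # view-specific predictors
                        'h_dim_p2': 50, 'num_layers_p2': 2,   # joint (multi-view) predictor
                        'h_dim_e': 50, 'num_layers_e': 2,     # encoders
                        'fc_activate_fn': tf.nn.relu,
                        'reg_scale': 0.}

    sess = tf.Session()
    model = DeepIMV_AISTATS(sess, 'deepimv', input_dims, network_settings)
    sess.run(tf.global_variables_initializer())

    # One training step on random data; mask[i, m] = 1 marks view m as observed.
    n = 32
    x_set = [np.random.randn(n, d).astype(np.float32) for d in input_dims['x_dim_set']]
    y = np.eye(2)[np.random.randint(0, 2, size=n)].astype(np.float32)
    mask = np.ones([n, 2], dtype=np.float32)

    _, loss_total, _, _, _, _, _ = model.train(x_set, y, mask,
                                               alpha_=1.0, beta_=0.01,
                                               lr_train=1e-4, k_prob=0.7)
    print('total loss: {:.4f}'.format(loss_total))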
|
|