# Tutorial
### Handwritten Dataset: https://archive.ics.uci.edu/ml/datasets/Multiple+Features

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import tensorflow as tf

import random
import sys, os

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
import import_data as impt
from helper import f_get_minibatch_set, evaluate
from class_DeepIMV_AISTATS import DeepIMV_AISTATS

### Import dataset
##### `X_set` is a list of arrays, one per view; views that are missing for a given sample are filled with `np.nan`.
##### Labels must be one-hot encoded (for a continuous target, use `Y_onehot = Y.reshape([-1, 1])` instead).

In [None]:
SEED         = 1234

# Seed the global Python/NumPy RNGs too, so anything that later draws from them
# (e.g. minibatch sampling, if f_get_minibatch_set uses np.random -- not visible
# here, TODO confirm) is reproducible under Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)

# this is a sample dataset used for our toy example.
X_set, Y_onehot, Mask = impt.import_incomplete_handwritten()
num_views = len(X_set)

# Every view, the labels and the mask must have one row per sample; the split
# below relies on that row alignment.
n_samples = np.shape(Y_onehot)[0]
assert all(np.shape(x)[0] == n_samples for x in X_set), 'every view must have one row per sample'
assert np.shape(Mask)[0] == n_samples, 'Mask must have one row per sample'


def split_aligned(arrays, test_size, seed):
    """Split several row-aligned arrays with one shared shuffle.

    Returns (train_parts, test_parts): two lists in the same order as `arrays`.
    """
    parts = train_test_split(*arrays, test_size=test_size, random_state=seed)
    return parts[0::2], parts[1::2]


# 64/16/20 training/validation/testing split (0.8*0.8 / 0.8*0.2 / 0.2).
# All views, labels and mask go through the same call, so their rows stay aligned.
trva_parts, te_parts = split_aligned([*X_set, Y_onehot, Mask], test_size=0.2, seed=SEED)
tr_parts, va_parts   = split_aligned(trva_parts, test_size=0.2, seed=SEED)

# Views are kept in dicts keyed by view index (0 .. num_views-1).
tr_X_set = dict(enumerate(tr_parts[:num_views]))
va_X_set = dict(enumerate(va_parts[:num_views]))
te_X_set = dict(enumerate(te_parts[:num_views]))
tr_Y_onehot, tr_M = tr_parts[num_views:]
va_Y_onehot, va_M = va_parts[num_views:]
te_Y_onehot, te_M = te_parts[num_views:]

In [None]:
# Directory where the best checkpoint (by validation loss) is written during training.
save_path = './storage/'

# exist_ok=True avoids the check-then-create race and keeps the cell idempotent;
# it still raises if save_path exists as a regular file.
os.makedirs(save_path, exist_ok=True)

### Hyper-parameters

In [None]:
# Minibatch size and the number of whole minibatches in the training split
# (floor division: a trailing partial batch is not counted).
mb_size         = 32
steps_per_batch = np.shape(tr_M)[0] // mb_size

# Input/output dimensions, read from the training split (one feature width per view).
x_dim_set = [tr_X_set[v].shape[1] for v in range(len(tr_X_set))]
y_dim     = np.shape(tr_Y_onehot)[1]
y_type    = 'categorical'

z_dim = 50  # presumably the latent/bottleneck size -- TODO confirm in DeepIMV_AISTATS

# Width/depth of the *_p networks (reused for both p1 and p2 below) and of the *_e network.
h_dim_p, num_layers_p = 100, 2
h_dim_e, num_layers_e = 100, 3

input_dims = dict(
    x_dim_set=x_dim_set,
    y_dim=y_dim,
    y_type=y_type,
    z_dim=z_dim,
    steps_per_batch=steps_per_batch,
)

network_settings = dict(
    h_dim_p1=h_dim_p,            # view-specific
    num_layers_p1=num_layers_p,  # view-specific
    h_dim_p2=h_dim_p,            # multi-view
    num_layers_p2=num_layers_p,  # multi-view
    h_dim_e=h_dim_e,
    num_layers_e=num_layers_e,
    fc_activate_fn=tf.nn.relu,
    reg_scale=0.,                # presumably 0 disables regularisation; 1e-4 is the suggested non-zero value
)


alpha   = 1.0
beta    = 0.01  # IB coefficient
lr_rate = 1e-4  # learning rate passed to model.train
k_prob  = 0.7   # passed to model.train; presumably a dropout keep-probability -- TODO confirm

In [None]:
# Build a fresh graph, session and model. If this cell is re-run, close the
# previous session first so its resources (e.g. GPU memory) are released.
_prev_sess = globals().get('sess')
if _prev_sess is not None:
    _prev_sess.close()

tf.reset_default_graph()
# The graph-level seed lives on the default graph, so it must be set after the reset.
tf.set_random_seed(SEED)

gpu_options = tf.GPUOptions()
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

model = DeepIMV_AISTATS(sess, "DeepIMV_AISTATS", input_dims, network_settings)

### Training

In [None]:
# Train with early stopping. Every STEPSIZE iterations: report the window-averaged
# losses, checkpoint the model if the average validation total loss (Lt) improved,
# and stop once it has failed to improve for max_flag consecutive checks.
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

ITERATION = 500000  # upper bound on training iterations
STEPSIZE  = 500     # iterations per report / checkpoint / early-stopping check

min_loss = np.inf   # best average validation Lt so far; inf so the first finite value is always saved
max_flag = 20       # early-stopping patience, counted in checks (max_flag * STEPSIZE iterations)

# Window averages of the six losses, in the order the model returns them:
# Lt, Lp, Lkl, Lps, Lkls, Lc.
tr_avg = np.zeros(6)
va_avg = np.zeros(6)

# Validation minibatches can't be larger than the validation split itself (loop-invariant).
va_mb_size = min(np.shape(va_M)[0], mb_size)

stop_flag = 0
for itr in range(ITERATION):
    # One optimisation step on a training minibatch.
    x_mb_set, y_mb, m_mb = f_get_minibatch_set(mb_size, tr_X_set, tr_Y_onehot, tr_M)
    _, Lt, Lp, Lkl, Lps, Lkls, Lc = model.train(x_mb_set, y_mb, m_mb, alpha, beta, lr_rate, k_prob)
    tr_avg += np.array([Lt, Lp, Lkl, Lps, Lkls, Lc], dtype=float) / STEPSIZE

    # Evaluate the losses on a validation minibatch, averaged over the same window.
    x_mb_set, y_mb, m_mb = f_get_minibatch_set(va_mb_size, va_X_set, va_Y_onehot, va_M)
    Lt, Lp, Lkl, Lps, Lkls, Lc, _, _ = model.get_loss(x_mb_set, y_mb, m_mb, alpha, beta)
    va_avg += np.array([Lt, Lp, Lkl, Lps, Lkls, Lc], dtype=float) / STEPSIZE

    if (itr + 1) % STEPSIZE != 0:
        continue

    # Score on the whole validation split; the second output of predict_ys is
    # averaged over axis 0 (the same reduction the test cell applies).
    _, y_preds = model.predict_ys(va_X_set, va_M)
    va_score = evaluate(va_Y_onehot, np.mean(y_preds, axis=0), y_type)

    print( "{:05d}: TRAIN| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} | VALID| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} score={}".format(
        itr + 1, *tr_avg, *va_avg, va_score)
         )

    if va_avg[0] < min_loss:
        min_loss  = va_avg[0]
        stop_flag = 0
        saver.save(sess, save_path  + 'best_model')
        print('saved...')
    else:
        stop_flag += 1

    # Start a fresh averaging window.
    tr_avg[:] = 0.0
    va_avg[:] = 0.0

    if stop_flag >= max_flag:
        break

if not np.isfinite(min_loss):
    # No check ever saved a checkpoint (e.g. ITERATION < STEPSIZE, or the validation
    # loss was never finite), so the testing cell will have nothing to restore.
    print('WARNING: no checkpoint was saved to {}'.format(save_path + 'best_model'))

print('FINISHED...')

### Testing

In [None]:
# Reload the checkpoint with the lowest average validation loss saved during training,
# then score the averaged predictions on the held-out test split.
best_ckpt = save_path + 'best_model'
saver.restore(sess, best_ckpt)

pred_ys = model.predict_ys(te_X_set, te_M)[1]
pred_y = np.mean(pred_ys, axis=0)

test_score = evaluate(te_Y_onehot, pred_y, y_type)
print('Test Score: {}'.format(test_score))