--- a
+++ b/tutorial.ipynb
@@ -0,0 +1,284 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Tutorial\n",
+    "### Handwritten Dataset: https://archive.ics.uci.edu/ml/datasets/Multiple+Features"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true,
+    "jupyter": {
+     "outputs_hidden": true
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "import random\n",
+    "import sys, os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import import_data as impt\n",
+    "from helper import f_get_minibatch_set, evaluate\n",
+    "from class_DeepIMV_AISTATS import DeepIMV_AISTATS"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Import dataset\n",
+    "##### X_set is a list of arrays in which missing views (for each sample) are replaced with np.nan\n",
+    "##### The label must be transformed into a one-hot variable (if continuous, use Y_onehot = Y.reshape([-1,1]))."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SEED = 1234\n",
+    "\n",
+    "# Seed the global Python and NumPy RNGs once, up front, so that any later code drawing from\n",
+    "# them is repeatable after a kernel restart.\n",
+    "# NOTE(review): presumably f_get_minibatch_set samples from one of these - confirm in helper.py.\n",
+    "random.seed(SEED)\n",
+    "np.random.seed(SEED)\n",
+    "\n",
+    "# this is a sample dataset used for our toy example.\n",
+    "X_set, Y_onehot, Mask = impt.import_incomplete_handwritten()\n",
+    "\n",
+    "num_views = len(X_set)\n",
+    "views     = [X_set[m] for m in range(num_views)]\n",
+    "\n",
+    "# 64/16/20 training/validation/testing split.\n",
+    "# All views, the labels and the mask are passed to ONE train_test_split call per stage, so they\n",
+    "# share a single row permutation and sklearn raises if their sample counts disagree.\n",
+    "# train_test_split returns [train_0, test_0, train_1, test_1, ...], hence the [0::2] / [1::2] slices.\n",
+    "outer_split        = train_test_split(*views, Y_onehot, Mask, test_size=0.2, random_state=SEED)\n",
+    "tr_parts, te_parts = outer_split[0::2], outer_split[1::2]\n",
+    "\n",
+    "# second stage: carve the validation set out of the remaining 80%\n",
+    "inner_split        = train_test_split(*tr_parts, test_size=0.2, random_state=SEED)\n",
+    "tr_parts, va_parts = inner_split[0::2], inner_split[1::2]\n",
+    "\n",
+    "# the first num_views entries are the views (stored in dicts keyed by view index);\n",
+    "# the last two entries are the labels and the mask, in that order\n",
+    "tr_X_set = dict(enumerate(tr_parts[:num_views]))\n",
+    "va_X_set = dict(enumerate(va_parts[:num_views]))\n",
+    "te_X_set = dict(enumerate(te_parts[:num_views]))\n",
+    "\n",
+    "tr_Y_onehot, tr_M = tr_parts[num_views:]\n",
+    "va_Y_onehot, va_M = va_parts[num_views:]\n",
+    "te_Y_onehot, te_M = te_parts[num_views:]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# directory that receives the model checkpoints written during training\n",
+    "save_path = './storage/'\n",
+    "\n",
+    "# exist_ok=True makes this safe to re-run and avoids the check-then-create race of\n",
+    "# testing os.path.exists() first.\n",
+    "os.makedirs(save_path, exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Hyper-parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ---- optimisation schedule ----\n",
+    "mb_size         = 32\n",
+    "steps_per_batch = np.shape(tr_M)[0] // mb_size   # whole minibatches that fit in the training set\n",
+    "\n",
+    "# ---- data dimensions ----\n",
+    "# second-axis size of every training view, listed in view order\n",
+    "x_dim_set = [tr_X_set[view].shape[1] for view in range(len(tr_X_set))]\n",
+    "y_dim     = np.shape(tr_Y_onehot)[1]\n",
+    "y_type    = 'categorical'\n",
+    "\n",
+    "# ---- latent size and network widths/depths ----\n",
+    "z_dim = 50\n",
+    "\n",
+    "h_dim_p, num_layers_p = 100, 2\n",
+    "h_dim_e, num_layers_e = 100, 3\n",
+    "\n",
+    "# settings bundle handed to the model constructor as its input-dimension argument\n",
+    "input_dims = {\n",
+    "    'x_dim_set': x_dim_set,\n",
+    "    'y_dim': y_dim,\n",
+    "    'y_type': y_type,\n",
+    "    'z_dim': z_dim,\n",
+    "    'steps_per_batch': steps_per_batch,\n",
+    "}\n",
+    "\n",
+    "# settings bundle handed to the model constructor as its network-settings argument\n",
+    "network_settings = {\n",
+    "    # view-specific part\n",
+    "    'h_dim_p1': h_dim_p,\n",
+    "    'num_layers_p1': num_layers_p,\n",
+    "    # multi-view part\n",
+    "    'h_dim_p2': h_dim_p,\n",
+    "    'num_layers_p2': num_layers_p,\n",
+    "    'h_dim_e': h_dim_e,\n",
+    "    'num_layers_e': num_layers_e,\n",
+    "    'fc_activate_fn': tf.nn.relu,\n",
+    "    'reg_scale': 0.,  # alternative value kept from the original: 1e-4\n",
+    "}\n",
+    "\n",
+    "# ---- loss weights and optimiser settings passed to model.train / model.get_loss ----\n",
+    "alpha   = 1.0\n",
+    "beta    = 0.01  # IB coefficient\n",
+    "lr_rate = 1e-4\n",
+    "k_prob  = 0.7"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# start from an empty default graph so the cell can be re-run without duplicating variables\n",
+    "tf.reset_default_graph()\n",
+    "\n",
+    "# Graph-level seed for TensorFlow's random ops. It has to be set after the reset above\n",
+    "# (which creates a fresh default graph) and before the model adds its ops to that graph.\n",
+    "tf.set_random_seed(SEED)\n",
+    "\n",
+    "gpu_options = tf.GPUOptions()\n",
+    "sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))\n",
+    "\n",
+    "model = DeepIMV_AISTATS(sess, \"DeepIMV_AISTATS\", input_dims, network_settings)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "saver = tf.train.Saver()\n",
+    "sess.run(tf.global_variables_initializer())\n",
+    "\n",
+    "ITERATION = 500000   # upper bound on the number of training iterations\n",
+    "STEPSIZE  = 500      # iterations per logging / validation window\n",
+    "\n",
+    "min_loss  = 1e+8     # lowest windowed validation Lt seen so far\n",
+    "max_flag  = 20       # stop after this many consecutive windows without improvement\n",
+    "stop_flag = 0        # consecutive windows without improvement\n",
+    "\n",
+    "NUM_LOSSES = 6       # Lt, Lp, Lkl, Lps, Lkls, Lc - in this order everywhere below\n",
+    "log_fmt = \"{:05d}: TRAIN| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} | VALID| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} score={}\"\n",
+    "\n",
+    "# running window means of the six losses for training and validation minibatches\n",
+    "tr_avg = [0] * NUM_LOSSES\n",
+    "va_avg = [0] * NUM_LOSSES\n",
+    "\n",
+    "# the validation set may be smaller than one minibatch; this does not change inside the loop\n",
+    "va_mb_size = min(np.shape(va_M)[0], mb_size)\n",
+    "\n",
+    "for itr in range(ITERATION):\n",
+    "    # one optimisation step on a training minibatch; the first returned value is discarded\n",
+    "    x_mb_set, y_mb, m_mb = f_get_minibatch_set(mb_size, tr_X_set, tr_Y_onehot, tr_M)\n",
+    "    _, *tr_losses = model.train(x_mb_set, y_mb, m_mb, alpha, beta, lr_rate, k_prob)\n",
+    "    tr_avg = [avg + loss/STEPSIZE for avg, loss in zip(tr_avg, tr_losses)]\n",
+    "\n",
+    "    # losses on a validation minibatch; the last two returned values are discarded\n",
+    "    x_mb_set, y_mb, m_mb = f_get_minibatch_set(va_mb_size, va_X_set, va_Y_onehot, va_M)\n",
+    "    *va_losses, _, _ = model.get_loss(x_mb_set, y_mb, m_mb, alpha, beta)\n",
+    "    va_avg = [avg + loss/STEPSIZE for avg, loss in zip(va_avg, va_losses)]\n",
+    "\n",
+    "    if (itr+1)%STEPSIZE == 0:\n",
+    "        # score the whole validation set using the second output of predict_ys, averaged over axis 0\n",
+    "        _, y_preds = model.predict_ys(va_X_set, va_M)\n",
+    "        va_score   = evaluate(va_Y_onehot, np.mean(y_preds, axis=0), y_type)\n",
+    "\n",
+    "        print(log_fmt.format(itr+1, *tr_avg, *va_avg, va_score))\n",
+    "\n",
+    "        # early stopping is driven by the windowed validation Lt (first entry of va_avg)\n",
+    "        va_avg_Lt = va_avg[0]\n",
+    "        if min_loss > va_avg_Lt:\n",
+    "            min_loss  = va_avg_Lt\n",
+    "            stop_flag = 0\n",
+    "            saver.save(sess, save_path  + 'best_model')\n",
+    "            print('saved...')\n",
+    "        else:\n",
+    "            stop_flag += 1\n",
+    "\n",
+    "        # start a fresh averaging window\n",
+    "        tr_avg = [0] * NUM_LOSSES\n",
+    "        va_avg = [0] * NUM_LOSSES\n",
+    "\n",
+    "        if stop_flag >= max_flag:\n",
+    "            break\n",
+    "\n",
+    "print('FINISHED...')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# reload the checkpoint written by the training cell\n",
+    "best_ckpt = save_path  + 'best_model'\n",
+    "saver.restore(sess, best_ckpt)\n",
+    "\n",
+    "# only the second output of predict_ys is used; it is averaged over axis 0 before scoring\n",
+    "_, pred_ys = model.predict_ys(te_X_set, te_M)\n",
+    "pred_y     = np.mean(pred_ys, axis=0)\n",
+    "\n",
+    "test_score = evaluate(te_Y_onehot, pred_y, y_type)\n",
+    "print('Test Score: {}'.format(test_score))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}