--- a
+++ b/reproduce_table3_base1.ipynb
@@ -0,0 +1,388 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Tutorial - TCGA\n",
+    "### Generating Results in Table 3 (TCGA Dataset)\n",
+    "### Method: BASE1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "import random\n",
+    "import sys, os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "import import_data as impt\n",
+    "from helper import f_get_minibatch_set, evaluate\n",
+    "from class_Baseline_Concat import Baseline_Concat"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Experiment identity; DATASET_PATH is the root of the checkpoint directory built in a later cell.\n",
+    "DATASET      = 'TCGA'\n",
+    "year         = 1\n",
+    "DATASET_PATH = 'TCGA_%dYR' % int(year)\n",
+    "model_name   = 'Base1'\n",
+    "\n",
+    "# 'incomplete' appends the *_incomp arrays to the training split later on; 'complete' leaves it untouched.\n",
+    "MODE         = 'incomplete'\n",
+    "\n",
+    "# Per-view inputs, one-hot labels and masks for the 'comp' and 'incomp' subsets\n",
+    "# (presumably fully observed vs. partially observed samples -- confirm in import_data).\n",
+    "(X_set_comp, Y_onehot_comp, Mask_comp,\n",
+    " X_set_incomp, Y_onehot_incomp, Mask_incomp) = impt.import_dataset_TCGA(year)\n",
+    "\n",
+    "# Number of entries in X_set_comp; every later cell loops over it as the number of views.\n",
+    "M = len(X_set_comp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SEED = 1234\n",
+    "OUTITERATION = 5\n",
+    "\n",
+    "# Result tables filled by the evaluation cell at [m_available-1, out_itr]:\n",
+    "# one row per number of available views (1..M), one column per outer-iteration index.\n",
+    "# Rows are sized from M (the number of views) rather than a hard-coded 4.\n",
+    "RESULTS_AUROC_RAND = np.zeros([M, OUTITERATION+2])\n",
+    "RESULTS_AUPRC_RAND = np.zeros([M, OUTITERATION+2])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# PREPARE DATASET (TRAIN / VALIDATION / TEST SPLIT)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out_itr = 1\n",
+    "\n",
+    "# Shared split settings: 20% held out each time, seeded per outer iteration.\n",
+    "# Reusing the same random_state makes every call shuffle rows identically, which keeps views,\n",
+    "# labels and masks aligned (assumes all of them have the same number of rows).\n",
+    "split_kwargs = dict(test_size=0.2, random_state=SEED + out_itr)\n",
+    "\n",
+    "# Labels and masks: first carve out the test set, then split validation off the remaining training part.\n",
+    "tr_Y_onehot, te_Y_onehot, tr_M, te_M = train_test_split(Y_onehot_comp, Mask_comp, **split_kwargs)\n",
+    "tr_Y_onehot, va_Y_onehot, tr_M, va_M = train_test_split(tr_Y_onehot, tr_M, **split_kwargs)\n",
+    "\n",
+    "# Same two-stage split applied to each view separately.\n",
+    "tr_X_set, va_X_set, te_X_set = {}, {}, {}\n",
+    "for m in range(M):\n",
+    "    tr_X_set[m], te_X_set[m] = train_test_split(X_set_comp[m], **split_kwargs)\n",
+    "    tr_X_set[m], va_X_set[m] = train_test_split(tr_X_set[m], **split_kwargs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(5850, 4)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# NOTE: this cell appends to tr_X_set / tr_Y_onehot / tr_M in place, so re-running it without\n",
+    "# re-running the split cell first would append the incomplete samples a second time.\n",
+    "\n",
+    "# Reject an unknown MODE before touching any data.\n",
+    "if MODE not in ('incomplete', 'complete'):\n",
+    "    raise ValueError('WRONG MODE!!!')\n",
+    "\n",
+    "if MODE == 'incomplete':\n",
+    "    # Extend the training split (only) with the incomplete subset; validation and test stay as split.\n",
+    "    for m in range(M):\n",
+    "        tr_X_set[m] = np.concatenate([tr_X_set[m], X_set_incomp[m]], axis=0)\n",
+    "\n",
+    "    tr_Y_onehot = np.concatenate([tr_Y_onehot, Y_onehot_incomp], axis=0)\n",
+    "    tr_M        = np.concatenate([tr_M, Mask_incomp], axis=0)\n",
+    "\n",
+    "print(tr_M.shape)\n",
+    "\n",
+    "\n",
+    "# Checkpoints for this configuration live under <DATASET_PATH>/M<M>_<MODE>/<model_name>/itr<out_itr>/.\n",
+    "save_path = '{}/M{}_{}/{}/'.format(DATASET_PATH, M, MODE, model_name)\n",
+    "\n",
+    "# exist_ok=True replaces the separate exists() check, so there is no gap between checking and creating.\n",
+    "os.makedirs(save_path + 'itr{}/'.format(out_itr), exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Hyper-parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Per-view input widths and number of label columns, read from the training arrays.\n",
+    "x_dim_set    = [np.shape(tr_X_set[m])[1] for m in range(len(tr_X_set))]\n",
+    "y_dim        = np.shape(tr_Y_onehot)[1]\n",
+    "y_type       = 'binary'\n",
+    "\n",
+    "# Layer sizes forwarded to Baseline_Concat; the _p / _e suffixes name two sub-networks\n",
+    "# (presumably predictor and encoder -- confirm in class_Baseline_Concat).\n",
+    "z_dim                 = 100\n",
+    "h_dim_p, num_layers_p = 100, 2\n",
+    "h_dim_e, num_layers_e = 100, 3\n",
+    "\n",
+    "# Data-dependent dimensions expected by the model constructor.\n",
+    "input_dims = dict(\n",
+    "    x_dim_set=x_dim_set,\n",
+    "    y_dim=y_dim,\n",
+    "    y_type=y_type,\n",
+    "    z_dim=z_dim,\n",
+    ")\n",
+    "\n",
+    "# Architecture options expected by the model constructor; reg_scale=0. is passed unchanged.\n",
+    "network_settings = dict(\n",
+    "    h_dim_p=h_dim_p,\n",
+    "    num_layers_p=num_layers_p,\n",
+    "    h_dim_e=h_dim_e,\n",
+    "    num_layers_e=num_layers_e,\n",
+    "    fc_activate_fn=tf.nn.relu,\n",
+    "    reg_scale=0.,\n",
+    ")\n",
+    "\n",
+    "# Training-loop settings: minibatch size, learning rate, and k_prob as passed to model.train\n",
+    "# (k_prob is presumably a dropout keep probability -- confirm in class_Baseline_Concat).\n",
+    "mb_size, lr_rate, k_prob = 32, 1e-4, 0.7\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Start from an empty default graph so re-running this cell does not pile up duplicate variables.\n",
+    "tf.reset_default_graph()\n",
+    "\n",
+    "# Fix the graph-level seed before any op is created so that weight initialisation and other\n",
+    "# seeded graph ops are repeatable across runs. It has to come after reset_default_graph(),\n",
+    "# because the seed is stored on the (new) default graph.\n",
+    "tf.set_random_seed(SEED + out_itr)\n",
+    "\n",
+    "gpu_options = tf.GPUOptions()\n",
+    "sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))\n",
+    "\n",
+    "# Build the baseline model inside this session from the dimension / architecture dictionaries.\n",
+    "model = Baseline_Concat(sess, \"Base1\", input_dims, network_settings)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "00500: TRAIN| LT=0.969   VALID| LT=0.804   auroc=0.7378   auprc=0.2977\n",
+      "saved...\n",
+      "01000: TRAIN| LT=0.743   VALID| LT=0.716   auroc=0.7685   auprc=0.3671\n",
+      "saved...\n",
+      "01500: TRAIN| LT=0.675   VALID| LT=0.721   auroc=0.7703   auprc=0.3434\n",
+      "02000: TRAIN| LT=0.637   VALID| LT=0.748   auroc=0.7757   auprc=0.3746\n",
+      "02500: TRAIN| LT=0.610   VALID| LT=0.777   auroc=0.7757   auprc=0.3422\n",
+      "03000: TRAIN| LT=0.580   VALID| LT=0.746   auroc=0.7717   auprc=0.3369\n",
+      "03500: TRAIN| LT=0.558   VALID| LT=0.774   auroc=0.7773   auprc=0.3451\n",
+      "04000: TRAIN| LT=0.534   VALID| LT=0.782   auroc=0.7663   auprc=0.3426\n",
+      "04500: TRAIN| LT=0.518   VALID| LT=0.810   auroc=0.7709   auprc=0.3266\n",
+      "05000: TRAIN| LT=0.491   VALID| LT=0.845   auroc=0.7719   auprc=0.3573\n",
+      "05500: TRAIN| LT=0.481   VALID| LT=0.863   auroc=0.7634   auprc=0.3282\n",
+      "06000: TRAIN| LT=0.447   VALID| LT=0.897   auroc=0.7594   auprc=0.3321\n",
+      "06500: TRAIN| LT=0.425   VALID| LT=0.951   auroc=0.7658   auprc=0.3384\n",
+      "07000: TRAIN| LT=0.397   VALID| LT=0.932   auroc=0.7734   auprc=0.3436\n",
+      "07500: TRAIN| LT=0.381   VALID| LT=1.012   auroc=0.7687   auprc=0.3375\n",
+      "08000: TRAIN| LT=0.370   VALID| LT=1.079   auroc=0.7547   auprc=0.3164\n",
+      "08500: TRAIN| LT=0.338   VALID| LT=1.129   auroc=0.7564   auprc=0.3148\n",
+      "09000: TRAIN| LT=0.338   VALID| LT=1.211   auroc=0.7571   auprc=0.3164\n",
+      "09500: TRAIN| LT=0.299   VALID| LT=1.251   auroc=0.7514   auprc=0.3062\n",
+      "10000: TRAIN| LT=0.283   VALID| LT=1.384   auroc=0.7558   auprc=0.3013\n",
+      "10500: TRAIN| LT=0.264   VALID| LT=1.407   auroc=0.7531   auprc=0.2913\n",
+      "11000: TRAIN| LT=0.250   VALID| LT=1.535   auroc=0.7510   auprc=0.2931\n",
+      "FINISHED...\n"
+     ]
+    }
+   ],
+   "source": [
+    "saver = tf.train.Saver()\n",
+    "sess.run(tf.global_variables_initializer())\n",
+    "\n",
+    "# Seed the Python and NumPy global RNGs so minibatch sampling is repeatable from run to run\n",
+    "# (assumes f_get_minibatch_set draws from one of them -- confirm in helper.py).\n",
+    "random.seed(SEED + out_itr)\n",
+    "np.random.seed(SEED + out_itr)\n",
+    "\n",
+    "ITERATION = 500000   # upper bound on the number of training steps\n",
+    "STEPSIZE  = 500      # steps between validation checks / log lines\n",
+    "\n",
+    "min_loss  = 1e+8     # lowest window-averaged validation loss seen so far\n",
+    "max_flag  = 20       # early-stopping patience, counted in validation checks\n",
+    "\n",
+    "# Running averages of the train / validation loss over the current STEPSIZE window.\n",
+    "tr_avg_Lt, va_avg_Lt  = 0, 0\n",
+    "\n",
+    "stop_flag = 0        # consecutive validation checks without improvement\n",
+    "for itr in range(ITERATION):\n",
+    "    # One training step on a training minibatch.\n",
+    "    x_mb_set, y_mb, m_mb = f_get_minibatch_set(mb_size, tr_X_set, tr_Y_onehot, tr_M)\n",
+    "    _, tr_Lt             = model.train(x_mb_set, y_mb, lr_rate, k_prob)\n",
+    "\n",
+    "    tr_avg_Lt += tr_Lt/STEPSIZE\n",
+    "\n",
+    "    # Loss only (model.get_loss) on a validation minibatch; the batch is capped at the validation set size.\n",
+    "    x_mb_set, y_mb, m_mb = f_get_minibatch_set(min(np.shape(va_M)[0], mb_size), va_X_set, va_Y_onehot, va_M)\n",
+    "    va_Lt                = model.get_loss(x_mb_set, y_mb)\n",
+    "\n",
+    "    va_avg_Lt += va_Lt/STEPSIZE\n",
+    "\n",
+    "    if (itr+1)%STEPSIZE == 0:\n",
+    "        # End of a window: score the whole validation set and log the averaged losses.\n",
+    "        y_pred = model.predict_y(va_X_set)\n",
+    "        auroc, auprc = evaluate(va_Y_onehot, y_pred, y_type)\n",
+    "\n",
+    "        print( \"{:05d}: TRAIN| LT={:.3f}   VALID| LT={:.3f}   auroc={:.4f}   auprc={:.4f}\".format(\n",
+    "            itr+1, tr_avg_Lt, va_avg_Lt, auroc, auprc))\n",
+    "\n",
+    "        # Model selection is driven by the averaged validation loss, not by auroc / auprc.\n",
+    "        if min_loss > va_avg_Lt:\n",
+    "            min_loss  = va_avg_Lt\n",
+    "            stop_flag = 0\n",
+    "            saver.save(sess, save_path  + 'itr{}/best_model'.format(out_itr))\n",
+    "            print('saved...')\n",
+    "        else:\n",
+    "            stop_flag += 1\n",
+    "\n",
+    "        # Reset the window accumulators.\n",
+    "        tr_avg_Lt = 0\n",
+    "        va_avg_Lt = 0\n",
+    "\n",
+    "        # Stop after max_flag consecutive checks without a new best validation loss.\n",
+    "        if stop_flag >= max_flag:\n",
+    "            break\n",
+    "\n",
+    "print('FINISHED...')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WARNING:tensorflow:From /home/vdslab/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
+      "Instructions for updating:\n",
+      "Use standard file APIs to check for files with this prefix.\n",
+      "INFO:tensorflow:Restoring parameters from TCGA_1YR/M4_incomplete/Base1/itr1/best_model\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Reload the checkpoint written by the training loop at its lowest validation loss before evaluating.\n",
+    "saver.restore(sess, '{}itr{}/best_model'.format(save_path, out_itr))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Evaluation -- (Results in Table 3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TEST - INCOMPLETE - #VIEW 1: auroc=0.6844  auprc=0.3007\n",
+      "TEST - INCOMPLETE - #VIEW 2: auroc=0.7444  auprc=0.3537\n",
+      "TEST - INCOMPLETE - #VIEW 3: auroc=0.7774  auprc=0.3685\n",
+      "TEST - INCOMPLETE - #VIEW 4: auroc=0.7805  auprc=0.3945\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Fill value for an unavailable view m: the feature row of the first incomplete sample whose mask\n",
+    "# marks view m as missing (presumably such rows already hold the imputed / mean value --\n",
+    "# confirm in import_data).\n",
+    "mean_X = {}\n",
+    "for m in range(M):\n",
+    "    missing_rows = X_set_incomp[m][Mask_incomp[:,m] == 0]\n",
+    "    if len(missing_rows) == 0:\n",
+    "        raise ValueError('no incomplete sample is missing view {}; cannot derive its fill value'.format(m))\n",
+    "    mean_X[m] = missing_rows[0, :]\n",
+    "\n",
+    "# Evaluate on the test split with exactly m_available randomly chosen views kept per sample.\n",
+    "# The RESULTS_* tables need at least M rows.\n",
+    "for m_available in range(1, M + 1):\n",
+    "    tmp_M_mis = np.zeros_like(te_M)\n",
+    "\n",
+    "    for i in range(len(tmp_M_mis)):\n",
+    "        # A private RandomState seeded per sample yields the same draw as\n",
+    "        # np.random.seed(...) followed by np.random.choice(...), without resetting the global RNG.\n",
+    "        rng = np.random.RandomState(SEED+out_itr+i)\n",
+    "        idx = rng.choice(M, m_available, replace=False)\n",
+    "        tmp_M_mis[i, idx] = 1\n",
+    "\n",
+    "    # Copy the test views and overwrite every masked-out (sample, view) entry with the fill value.\n",
+    "    tmp_te_X = {}\n",
+    "    for m in range(M):\n",
+    "        tmp_te_X[m] = np.copy(te_X_set[m])\n",
+    "        tmp_te_X[m][tmp_M_mis[:,m] == 0] = mean_X[m]\n",
+    "\n",
+    "    y_pred = model.predict_y(tmp_te_X)\n",
+    "    auc1, apc1 = evaluate(te_Y_onehot, y_pred, y_type)\n",
+    "\n",
+    "    # Row = number of available views - 1, column = outer-iteration index.\n",
+    "    RESULTS_AUROC_RAND[m_available-1, out_itr] = auc1\n",
+    "    RESULTS_AUPRC_RAND[m_available-1, out_itr] = apc1\n",
+    "\n",
+    "    print(\"TEST - {} - #VIEW {}: auroc={:.4f}  auprc={:.4f}\".format(MODE.upper(), m_available,  auc1, apc1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}