Switch to side-by-side view

--- a
+++ b/notebooks/visualization.ipynb
@@ -0,0 +1,505 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ".DS_Store\n",
+      "EHealthQAwithBertAndFFNNEmbeddings.csv\n",
+      "HealthTapFFNNEmbeddings.csv\n",
+      "askDocsFFNNEmbeddings.csv\n",
+      "webMDFFNNEmbeddings.csv\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import faiss\n",
+    "from sklearn.preprocessing import StandardScaler, MinMaxScaler, QuantileTransformer\n",
+    "from plotly import offline\n",
+    "from sklearn.decomposition import TruncatedSVD\n",
+    "from MulticoreTSNE import MulticoreTSNE as TSNE\n",
+    "import umap\n",
+    "\n",
+    "import plotly.plotly as py\n",
+    "import plotly.graph_objs as go\n",
+    "from plotly.offline import init_notebook_mode, iplot\n",
+    "\n",
+    "files = os.listdir(\"../data\")\n",
+    "\n",
+    "def fix_array(x):\n",
+    "    x = np.fromstring(\n",
+    "    x.replace('\\n','')\n",
+    "    .replace('[','')\n",
+    "    .replace(']','')\n",
+    "    .replace('  ',' '), sep=' ')\n",
+    "    return x.reshape((1, 768))\n",
+    "\n",
+    "qa = pd.read_csv(\"../data/\" + files[0])\n",
+    "for file in files[1:]:\n",
+    "    print(file)\n",
+    "    qa = pd.concat([qa, pd.read_csv(\"../data/\" + file)], axis = 0)\n",
+    "    \n",
+    "\n",
+    "qa.drop([\"answer_bert\", \"question_bert\", \"Unnamed: 0\"], axis = 1, inplace = True)\n",
+    "\n",
+    "qa[\"Q_FFNN_embeds\"] = qa[\"Q_FFNN_embeds\"].apply(fix_array)\n",
+    "qa[\"A_FFNN_embeds\"] = qa[\"A_FFNN_embeds\"].apply(fix_array)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ],
+      "text/vnd.plotly.v1+html": [
+       "<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "for n_items in range(100, 5000, 500):\n",
+    "    for perplexity in range(1, 40, 5):\n",
+    "        n_iters = 5000\n",
+    "        qa = qa.sample(frac = 1)\n",
+    "        qa.reset_index(inplace = True, drop = True)\n",
+    "        question_bert = np.concatenate(qa[\"Q_FFNN_embeds\"].values, axis=0)\n",
+    "        answer_bert = np.concatenate(qa[\"A_FFNN_embeds\"].values, axis=0)\n",
+    "        question_bert = question_bert.astype('float32')\n",
+    "        answer_bert = answer_bert.astype('float32')\n",
+    "\n",
+    "        answer_index = faiss.IndexFlatIP(answer_bert.shape[-1])\n",
+    "        answer_index.add(answer_bert)\n",
+    "\n",
+    "        question_index = faiss.IndexFlatIP(question_bert.shape[-1])\n",
+    "        question_index.add(question_bert)\n",
+    "\n",
+    "        k = len(question_bert)\n",
+    "        D1, I1 = answer_index.search(question_bert[0:1].astype('float32'), k)\n",
+    "        D2, I2 = question_index.search(question_bert[0:1].astype('float32'), k)\n",
+    "        QT = QuantileTransformer()\n",
+    "        D2 = (QT.fit_transform(-D2.T)**4) #* 20\n",
+    "        D1 = (QT.fit_transform(-D1.T)**4) #* 20\n",
+    "        closest_ind_q = list(I2[0, :n_items]) +list(I2[0, -n_items:])                                          \n",
+    "        closest_ind_a = list(I1[0, :n_items]) +list(I1[0, -n_items:])\n",
+    "        dist_answers = answer_bert[closest_ind_a, :]\n",
+    "        dist_questions = question_bert[closest_ind_q, :]\n",
+    "        D1_answers = D1[closest_ind_a, :]\n",
+    "        D2_questions = D2[closest_ind_q, :]\n",
+    "        reducer = TSNE(n_components = 3, perplexity=perplexity, n_iter = n_iters)\n",
+    "        reduced_dimensions = reducer.fit_transform(np.concatenate([dist_questions, dist_answers, answer_bert[0:1]], axis = 0))\n",
+    "        question_bert_3d_close = reduced_dimensions[:n_items]\n",
+    "        question_bert_3d_far = reduced_dimensions[n_items:n_items*2]\n",
+    "        answer_bert_3d_close = reduced_dimensions[n_items*2:n_items*3]\n",
+    "        answer_bert_3d_far = reduced_dimensions[n_items*3:-1]\n",
+    "        question_bert_dist_close = D2_questions[:n_items]\n",
+    "        question_bert_dist_far = D2_questions[n_items:n_items*2]\n",
+    "        answer_bert_dist_close = D1_answers[n_items*2:n_items*3]\n",
+    "        answer_bert_dist_far = D1_answers[n_items*3:-1]\n",
+    "\n",
+    "        init_notebook_mode(connected=True)\n",
+    "\n",
+    "        orig_q = go.Scatter3d(\n",
+    "            name = \"Original Question\",\n",
+    "            x=question_bert_3d_close[0:1,0],\n",
+    "            y=question_bert_3d_close[0:1,1],\n",
+    "            z=question_bert_3d_close[0:1,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"question\"].loc[closest_ind_q[:1]],\n",
+    "            marker=dict(\n",
+    "                size=12,\n",
+    "                line=dict(\n",
+    "                    color='rgba(255, 0, 0, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=1.0\n",
+    "            )\n",
+    "        )\n",
+    "        orig_a = go.Scatter3d(\n",
+    "            name = \"Original Answer\",\n",
+    "            x=reduced_dimensions[-1:,0],\n",
+    "            y=reduced_dimensions[-1:,1],\n",
+    "            z=reduced_dimensions[-1:,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"answer\"][0:1],\n",
+    "            marker=dict(\n",
+    "                size=12,\n",
+    "                line=dict(\n",
+    "                    color='rgba(0, 255, 0, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=1.0\n",
+    "            )\n",
+    "        )\n",
+    "        recommended_a = go.Scatter3d(\n",
+    "            name = \"Recommended Answers\",\n",
+    "            x=answer_bert_3d_close[0:5,0],\n",
+    "            y=answer_bert_3d_close[0:5,1],\n",
+    "            z=answer_bert_3d_close[0:5,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"answer\"].loc[closest_ind_a[:5]],\n",
+    "            marker=dict(\n",
+    "                size=12,\n",
+    "                line=dict(\n",
+    "                    color='rgba(0, 255, 0, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=1.0\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        close_q = go.Scatter3d(\n",
+    "            name = \"Similar Questions\",\n",
+    "            x=question_bert_3d_close[:,0],\n",
+    "            y=question_bert_3d_close[:,1],\n",
+    "            z=question_bert_3d_close[:,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"question\"].loc[closest_ind_q],\n",
+    "            marker=dict(\n",
+    "                size=question_bert_dist_close*16,\n",
+    "                line=dict(\n",
+    "                    color='rgba(217, 217, 217, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=0.8\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        close_a = go.Scatter3d(\n",
+    "            name = \"Similar Answers\",\n",
+    "            x=answer_bert_3d_close[5:,0],\n",
+    "            y=answer_bert_3d_close[5:,1],\n",
+    "            z=answer_bert_3d_close[5:,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"answer\"].loc[closest_ind_a],\n",
+    "            marker=dict(\n",
+    "                size=answer_bert_dist_close*16,\n",
+    "                line=dict(\n",
+    "                    color='rgba(244, 100, 40, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=0.8\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        far_q = go.Scatter3d(\n",
+    "            name = \"Dissimilar Questions\",\n",
+    "            x=question_bert_3d_far[:,0],\n",
+    "            y=question_bert_3d_far[:,1],\n",
+    "            z=question_bert_3d_far[:,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"question\"].loc[closest_ind_q],\n",
+    "            marker=dict(\n",
+    "                size=question_bert_dist_far,\n",
+    "                line=dict(\n",
+    "                    color='rgba(40, 100, 217, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=0.8\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        far_a = go.Scatter3d(\n",
+    "            name = \"Dissimilar Answers\",\n",
+    "            x=answer_bert_3d_far[:,0],\n",
+    "            y=answer_bert_3d_far[:,1],\n",
+    "            z=answer_bert_3d_far[:,2],\n",
+    "            mode='markers',\n",
+    "            text = qa[\"answer\"].loc[closest_ind_a],\n",
+    "            marker=dict(\n",
+    "                size=answer_bert_dist_far,\n",
+    "                line=dict(\n",
+    "                    color='rgba(255, 40, 40, 0.14)',\n",
+    "                    width=0.1\n",
+    "                ),\n",
+    "                opacity=0.8\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        data = [orig_q, orig_a, close_q, close_a, \n",
+    "                #far_q, far_a, \n",
+    "                recommended_a\n",
+    "               ]\n",
+    "        layout = go.Layout(\n",
+    "            margin=dict(\n",
+    "                l=0,\n",
+    "                r=0,\n",
+    "                b=0,\n",
+    "                t=0\n",
+    "            )\n",
+    "        )\n",
+    "        fig = go.Figure(data=data, layout=layout)\n",
+    "        #iplot(fig, filename='simple-3d-scatter')\n",
+    "\n",
+    "        offline.plot(fig, filename=\"./experiments/n_items_\" + str(n_items) + \"_perplexity_\" + str(perplexity) + '.html', auto_open=False) \n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}