import pandas as pd
import numpy as np
import os
import faiss
from sklearn.preprocessing import StandardScaler, MinMaxScaler, QuantileTransformer
from plotly import offline
from sklearn.decomposition import TruncatedSVD
from MulticoreTSNE import MulticoreTSNE as TSNE
import umap

import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot

# --- Configuration -------------------------------------------------------
DATA_DIR = "../data"
EMBED_DIM = 768  # width every embedding row is reshaped to

# os.listdir returns entries in arbitrary order and also lists non-data files
# (e.g. macOS ".DS_Store"), so keep only the CSVs and sort them to make the
# load order deterministic.
files = sorted(name for name in os.listdir(DATA_DIR) if name.endswith(".csv"))
if not files:
    raise FileNotFoundError("no .csv files found in " + DATA_DIR)


def fix_array(x, dim=EMBED_DIM):
    """Parse an embedding stored as bracketed, whitespace-separated text.

    Parameters
    ----------
    x : str
        Text such as ``"[[0.1 0.2\\n 0.3]]"``; brackets and any amount of
        whitespace (including newlines) between the numbers are ignored.
    dim : int, optional
        Number of values the text must contain (default ``EMBED_DIM``).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(1, dim)``.

    Raises
    ------
    ValueError
        If a token is not a number or the number of values is not ``dim``.
    """
    # str.split() with no argument collapses every run of whitespace, so no
    # manual newline / double-space clean-up is needed.
    tokens = x.replace('[', ' ').replace(']', ' ').split()
    return np.array(tokens, dtype=np.float64).reshape((1, dim))


# Read every file first and concatenate once: calling pd.concat inside the
# loop re-copies the accumulated frame on each iteration.
frames = []
for file in files:
    print(file)
    frames.append(pd.read_csv(os.path.join(DATA_DIR, file)))
qa = pd.concat(frames, axis=0, ignore_index=True)

# Non-inplace drop keeps this cell safe to re-run on a fresh `qa`.
qa = qa.drop(["answer_bert", "question_bert", "Unnamed: 0"], axis=1)

qa["Q_FFNN_embeds"] = qa["Q_FFNN_embeds"].apply(fix_array)
qa["A_FFNN_embeds"] = qa["A_FFNN_embeds"].apply(fix_array)
# --- Experiment configuration --------------------------------------------
SEED = 42            # seeds query selection and t-SNE so runs are repeatable
N_ITERS = 5000       # t-SNE optimisation steps per experiment
N_RECOMMENDED = 5    # top-ranked answers drawn as "Recommended Answers"
SHOW_FAR = False     # also draw the "Dissimilar ..." traces
OUT_DIR = "./experiments"


def _marker_trace(name, points, text, size, line_color, opacity):
    """Build one 3-D marker trace.

    Parameters
    ----------
    name : str
        Legend label.
    points : numpy.ndarray
        Array whose first three columns are used as x, y and z.
    text : sequence of str
        Hover text, one entry per row of ``points``.
    size : float or 1-D array
        Marker size (scalar, or one value per row of ``points``).
    line_color : str
        Colour of the marker outline.
    opacity : float
        Marker opacity.

    Returns
    -------
    plotly.graph_objs.Scatter3d
    """
    return go.Scatter3d(
        name=name,
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        text=text,
        marker=dict(
            size=size,
            line=dict(color=line_color, width=0.1),
            opacity=opacity,
        ),
    )


# Only needed for inline `iplot`; calling it once (instead of once per
# experiment) avoids emitting the same <script> output dozens of times.
init_notebook_mode(connected=True)

# offline.plot does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

# --- Loop-invariant setup ------------------------------------------------
# Positional 0..N-1 index so the row ids returned by faiss address rows of
# this frame directly.
qa_rows = qa.reset_index(drop=True)
question_bert = np.concatenate(qa_rows["Q_FFNN_embeds"].values, axis=0).astype('float32')
answer_bert = np.concatenate(qa_rows["A_FFNN_embeds"].values, axis=0).astype('float32')

answer_index = faiss.IndexFlatIP(answer_bert.shape[-1])
answer_index.add(answer_bert)

question_index = faiss.IndexFlatIP(question_bert.shape[-1])
question_index.add(question_bert)

n_rows = len(question_bert)
rng = np.random.RandomState(SEED)
QT = QuantileTransformer()

for n_items in range(100, 5000, 500):
    for perplexity in range(1, 40, 5):
        # A fresh random query per experiment, drawn from the seeded
        # generator instead of reshuffling (and rebinding) the global `qa`.
        query_pos = rng.randint(n_rows)
        query = question_bert[query_pos:query_pos + 1]

        # Rank every answer / question against the query.
        D1, I1 = answer_index.search(query, n_rows)
        D2, I2 = question_index.search(query, n_rows)
        D2 = QT.fit_transform(-D2.T) ** 4
        D1 = QT.fit_transform(-D1.T) ** 4

        # First n_items = best-ranked rows, last n_items = worst-ranked rows.
        closest_ind_q = list(I2[0, :n_items]) + list(I2[0, -n_items:])
        closest_ind_a = list(I1[0, :n_items]) + list(I1[0, -n_items:])
        dist_questions = question_bert[closest_ind_q, :]
        dist_answers = answer_bert[closest_ind_a, :]

        # NOTE(review): faiss returns D and I ordered by rank, so row r of
        # D1/D2 belongs to item I[0, r], not to item r. Indexing the scaled
        # distances with item ids (kept from the original) therefore does not
        # select each neighbour's own value. The rank-aligned values would be
        # D[:n_items] / D[-n_items:], but with the current `-D` quantile
        # mapping that shrinks every near-neighbour marker to ~0 -- confirm
        # the intended size encoding before changing this.
        # Column 0 is taken so the sizes are 1-D, one value per marker.
        D2_questions = D2[closest_ind_q, 0]
        D1_answers = D1[closest_ind_a, 0]

        reducer = TSNE(n_components=3, perplexity=perplexity,
                       n_iter=N_ITERS, random_state=SEED)
        # Row layout of the embedding: n_items close questions, n_items far
        # questions, n_items close answers, n_items far answers, and finally
        # the query's own answer.
        reduced_dimensions = reducer.fit_transform(
            np.concatenate(
                [dist_questions, dist_answers, answer_bert[query_pos:query_pos + 1]],
                axis=0,
            )
        )
        question_bert_3d_close = reduced_dimensions[:n_items]
        question_bert_3d_far = reduced_dimensions[n_items:n_items * 2]
        answer_bert_3d_close = reduced_dimensions[n_items * 2:n_items * 3]
        answer_bert_3d_far = reduced_dimensions[n_items * 3:-1]

        # The size arrays hold only 2 * n_items values each (close, then
        # far). The original reused the 2n / 3n offsets of
        # `reduced_dimensions` for the answers, which always produced empty
        # slices.
        question_bert_dist_close = D2_questions[:n_items]
        question_bert_dist_far = D2_questions[n_items:]
        answer_bert_dist_close = D1_answers[:n_items]
        answer_bert_dist_far = D1_answers[n_items:]

        questions = qa_rows["question"]
        answers = qa_rows["answer"]

        # Drawn at the top-ranked question, which is assumed to be the query
        # itself -- inner-product search does not guarantee that.
        orig_q = _marker_trace(
            "Original Question",
            question_bert_3d_close[0:1],
            questions.iloc[closest_ind_q[:1]],
            12, 'rgba(255, 0, 0, 0.14)', 1.0,
        )
        orig_a = _marker_trace(
            "Original Answer",
            reduced_dimensions[-1:],
            answers.iloc[query_pos:query_pos + 1],
            12, 'rgba(0, 255, 0, 0.14)', 1.0,
        )
        recommended_a = _marker_trace(
            "Recommended Answers",
            answer_bert_3d_close[0:N_RECOMMENDED],
            answers.iloc[closest_ind_a[:N_RECOMMENDED]],
            12, 'rgba(0, 255, 0, 0.14)', 1.0,
        )
        close_q = _marker_trace(
            "Similar Questions",
            question_bert_3d_close,
            questions.iloc[closest_ind_q[:n_items]],
            question_bert_dist_close * 16, 'rgba(217, 217, 217, 0.14)', 0.8,
        )
        # The recommended answers are drawn separately, so points, hover text
        # and sizes must all skip the same leading N_RECOMMENDED entries.
        close_a = _marker_trace(
            "Similar Answers",
            answer_bert_3d_close[N_RECOMMENDED:],
            answers.iloc[closest_ind_a[N_RECOMMENDED:n_items]],
            answer_bert_dist_close[N_RECOMMENDED:] * 16, 'rgba(244, 100, 40, 0.14)', 0.8,
        )

        data = [orig_q, orig_a, close_q, close_a]
        if SHOW_FAR:
            # Hover text comes from the far half of the neighbour lists.
            data.append(_marker_trace(
                "Dissimilar Questions",
                question_bert_3d_far,
                questions.iloc[closest_ind_q[n_items:]],
                question_bert_dist_far, 'rgba(40, 100, 217, 0.14)', 0.8,
            ))
            data.append(_marker_trace(
                "Dissimilar Answers",
                answer_bert_3d_far,
                answers.iloc[closest_ind_a[n_items:]],
                answer_bert_dist_far, 'rgba(255, 40, 40, 0.14)', 0.8,
            ))
        data.append(recommended_a)

        layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
        fig = go.Figure(data=data, layout=layout)

        offline.plot(
            fig,
            filename=os.path.join(
                OUT_DIR,
                "n_items_" + str(n_items) + "_perplexity_" + str(perplexity) + '.html',
            ),
            auto_open=False,
        )