[7b0fc8]: / MED277_bot.ipynb

Download this file

884 lines (883 with data), 30.1 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:28.421703Z",
     "start_time": "2018-06-11T22:22:26.327228Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.externals import joblib\n",
    "import re\n",
    "from nltk.stem.snowball import SnowballStemmer\n",
    "from collections import defaultdict\n",
    "import operator\n",
    "import numpy as np\n",
    "import sklearn.feature_extraction.text as text\n",
    "from sklearn import decomposition\n",
    "from nltk.stem import PorterStemmer, WordNetLemmatizer\n",
    "from sklearn.decomposition import PCA\n",
    "from numpy.linalg import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:30.024471Z",
     "start_time": "2018-06-11T22:22:30.000549Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function loads discharge summary data from MIMIC-III dataset NOTEEVENTS.csv.gz file.\n",
    "update the base_path to be the location of this zipped '''\n",
    "def load_data():\n",
    "    ## Intitializing data paths\n",
    "    base_path = r'D:\\ORGANIZATION\\UCSD_Life\\Work\\4. Quarter-3\\Subjects\\MED 277\\Project\\DATA\\\\'\n",
    "    data_file = base_path+\"NOTEEVENTS.csv.gz\"\n",
    "    \n",
    "    ## Loading data frames from CSV file\n",
    "    df = pd.read_csv(data_file, compression='gzip')\n",
    "    \n",
    "    ## Uncomment this to slice the size of dataset\n",
    "    #df = df[:10000]\n",
    "    \n",
    "    ## Uncomment this to save processed data to the memory\n",
    "    #joblib.dump(df,base_path+'data10.pkl')\n",
    "    ## loading data frames from PKL memory\n",
    "    #df1 =  joblib.load(base_path+'data10.pkl')\n",
    "    \n",
    "    ## Filtering dataframe for \"Discharge summaries\" and \"TEXT\"\n",
    "    df = df.loc[df['CATEGORY'] == 'Discharge summary'] #Extracting only discharge summaries\n",
    "    df_text = df['TEXT']\n",
    "    return df_text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EXTRACT ALL THE TOPICS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:32.608559Z",
     "start_time": "2018-06-11T22:22:32.603573Z"
    }
   },
   "outputs": [],
   "source": [
    "'''Method that processes the entire document string'''\n",
    "def process_text(txt):\n",
    "    txt1 = re.sub('[\\n]',\" \",txt)\n",
    "    txt1 = re.sub('[^A-Za-z \\.]+', '', txt1)\n",
    "    \n",
    "    return txt1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:33.596917Z",
     "start_time": "2018-06-11T22:22:33.586970Z"
    }
   },
   "outputs": [],
   "source": [
    "'''Method that processes the document string not considering separate lines'''\n",
    "def process(txt):\n",
    "    txt1 = re.sub('[\\n]',\" \",txt)\n",
    "    txt1 = re.sub('[^A-Za-z ]+', '', txt1)\n",
    "    \n",
    "    _wrds = txt1.split()\n",
    "    stemmer = SnowballStemmer(\"english\") ## May use porter stemmer\n",
    "    wrds = [stemmer.stem(wrd) for wrd in _wrds]\n",
    "    return wrds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:34.412734Z",
     "start_time": "2018-06-11T22:22:34.402790Z"
    }
   },
   "outputs": [],
   "source": [
    "'''Method that processes raw string and gets a processes list containing lines'''\n",
    "def get_processed_sentences(snt_txt):\n",
    "    snt_list = []\n",
    "    for line in snt_txt.split('.'):\n",
    "        line = line.strip()\n",
    "        if len(line.split()) >= 5:\n",
    "            snt_list.append(line)\n",
    "    return snt_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:35.078078Z",
     "start_time": "2018-06-11T22:22:35.051176Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This method extracts topic from sentence'''\n",
    "def extract_topic(str_arg, num_topics = 1, num_top_words = 3):\n",
    "    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')\n",
    "    try:\n",
    "        dtm = vectorizer.fit_transform(str_arg.split())\n",
    "        vocab = np.array(vectorizer.get_feature_names())\n",
    "    \n",
    "        #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction\n",
    "        clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')\n",
    "        clf.fit_transform(dtm)\n",
    "\n",
    "        topic_words = []\n",
    "        for topic in clf.components_:\n",
    "            word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list\n",
    "            topic_words.append([vocab[i] for i in word_idx])\n",
    "        return topic_words\n",
    "    except:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:35.804137Z",
     "start_time": "2018-06-11T22:22:35.794166Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This method extracts topics of each sentence and returns a list'''\n",
    "def extract_topics_all(doc_string):\n",
    "    #One entry per sentence in list\n",
    "    doc_str = process_text(doc_string)\n",
    "    doc_str = get_processed_sentences(doc_str)\n",
    "    \n",
    "    res = []\n",
    "    for i in range (0, len(doc_str)):\n",
    "        snd_str = doc_str[i].lower()\n",
    "        #print(\"Sending ----------------------------\",snd_str,\"==========\",len(snd_str))\n",
    "        tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)\n",
    "        for top in tmp_topic:\n",
    "            for wrd in top:\n",
    "                res.append(wrd)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:36.470386Z",
     "start_time": "2018-06-11T22:22:36.462381Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function takes a dataframe and returns all the topics in the entire corpus'''\n",
    "def extract_corpus_topics(arg_df):\n",
    "    all_topics = set()\n",
    "    cnt = 1\n",
    "    for txt in arg_df:\n",
    "        all_topics = all_topics.union(extract_topics_all(txt))\n",
    "        print(\"Processed \",cnt,\" records\")\n",
    "        cnt += 1\n",
    "    all_topics = list(all_topics)\n",
    "    return all_topics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GET A VECTORIZED REPRESENTATION OF ALL THE TOPICS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:38.161868Z",
     "start_time": "2018-06-11T22:22:38.140924Z"
    }
   },
   "outputs": [],
   "source": [
    "'''data_set = words list per document.\n",
    "    vocabulary = list of all the words present\n",
    "    _vocab = dict of word counts for words in vocabulary'''\n",
    "def get_vocab_wrd_map(df_text):\n",
    "    data_set = []\n",
    "    vocabulary = []\n",
    "    _vocab = defaultdict(int)\n",
    "    for i in range(0,df_text.size):\n",
    "        txt = process(df_text[i])\n",
    "        data_set.append(txt)\n",
    "\n",
    "        for wrd in txt:\n",
    "            _vocab[wrd] += 1\n",
    "\n",
    "        vocabulary = vocabulary + txt\n",
    "        vocabulary = list(set(vocabulary))\n",
    "\n",
    "        if(i%100 == 0):\n",
    "            print(\"%5d records processed\"%(i))\n",
    "    return data_set, vocabulary, _vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:39.105476Z",
     "start_time": "2018-06-11T22:22:39.099498Z"
    }
   },
   "outputs": [],
   "source": [
    "'''vocab = return sorted list of most common words in vocabulary'''\n",
    "def get_common_vocab(num_arg, vocab):\n",
    "    vocab = sorted(vocab.items(), key=operator.itemgetter(1), reverse=True)\n",
    "    vocab = vocab[:num_arg]\n",
    "    return vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:39.851482Z",
     "start_time": "2018-06-11T22:22:39.839514Z"
    }
   },
   "outputs": [],
   "source": [
    "'''Convert vocabulary and most common words to map for faster access'''\n",
    "def get_vocab_map(vocabulary, vocab):\n",
    "    vocab_map = {}\n",
    "    for i in range(0,len(vocab)):\n",
    "        vocab_map[vocab[i][0]] = i \n",
    "    \n",
    "    vocabulary_map = {}\n",
    "    for i in range(0,len(vocabulary)):\n",
    "        vocabulary_map[vocabulary[i]] = i\n",
    "    \n",
    "    return vocabulary_map, vocab_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:40.626409Z",
     "start_time": "2018-06-11T22:22:40.609455Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function returns n-gram context embedding for each word'''\n",
    "def get_embedding(word, data_set, vocab_map, wdw_size):\n",
    "    embedding = [0]*len(vocab_map)\n",
    "    for docs in data_set:\n",
    "        for i in range(wdw_size, len(docs)-wdw_size):\n",
    "            if docs[i] == word:\n",
    "                for j in range(i-wdw_size, i-1):\n",
    "                    if docs[j] in vocab_map:\n",
    "                        embedding[vocab_map[docs[j]]] += 1\n",
    "                for j in range(i+1, i+wdw_size):\n",
    "                    if docs[j] in vocab_map:\n",
    "                        embedding[vocab_map[docs[j]]] += 1\n",
    "    total_words = sum(embedding)\n",
    "    if total_words != 0:\n",
    "        embedding[:] = [e/total_words for e in embedding]\n",
    "    return embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:41.413308Z",
     "start_time": "2018-06-11T22:22:41.405327Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This is a helper function that returns n-gram embedding for all the topics in the corpus'''\n",
    "def get_embedding_all(all_topics, data_set, vocab_map, wdw_size):\n",
    "    embeddings = []\n",
    "    for i in range(0, len(all_topics)):\n",
    "        embeddings.append(get_embedding(all_topics[i], data_set, vocab_map, wdw_size))\n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get similarity function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:42.876508Z",
     "start_time": "2018-06-11T22:22:42.868529Z"
    }
   },
   "outputs": [],
   "source": [
    "def cos_matrix_multiplication(matrix, vector):\n",
    "    \"\"\"\n",
    "    Calculating pairwise cosine distance using matrix vector multiplication.\n",
    "    \"\"\"\n",
    "    dotted = matrix.dot(vector)\n",
    "    matrix_norms = np.linalg.norm(matrix, axis=1)\n",
    "    vector_norm = np.linalg.norm(vector)\n",
    "    matrix_vector_norms = np.multiply(matrix_norms, vector_norm)\n",
    "    neighbors = np.divide(dotted, matrix_vector_norms)\n",
    "    return neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:43.710277Z",
     "start_time": "2018-06-11T22:22:43.695318Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function generates most similar topic to a given embedding'''\n",
    "def get_most_similar_topics(embd, embeddings, all_topics, num_wrd=10):\n",
    "    sim_top = []\n",
    "    cos_sim = cos_matrix_multiplication(np.array(embeddings), embd)\n",
    "    #closest_match = cos_sim.argsort()[-num_wrd:][::-1] ## This sorts all matches in order\n",
    "    \n",
    "    ## This just takes 80% and above similar matches\n",
    "    idx = list(np.where(cos_sim > 0.9)[0])\n",
    "    val = list(cos_sim[np.where(cos_sim > 0.9)])\n",
    "    closest_match, list2 = (list(t) for t in zip(*sorted(zip(idx, val), reverse=True)))\n",
    "    closest_match = np.array(closest_match)\n",
    "    \n",
    "    for i in range(0, closest_match.shape[0]):\n",
    "        sim_top.append(all_topics[closest_match[i]])\n",
    "    return sim_top"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Topic Modelling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:44.978887Z",
     "start_time": "2018-06-11T22:22:44.972904Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function extracts matches for a regular expression in the text'''\n",
    "def get_regex_match(regex, str_arg):\n",
    "    srch = re.search(regex,str_arg)\n",
    "    if srch is not None:\n",
    "        return srch.group(0).strip()\n",
    "    else:\n",
    "        return \"Not found\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:45.895464Z",
     "start_time": "2018-06-11T22:22:45.878481Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This is a helper function that helps extracting answer to extraction type questions'''\n",
    "def extract(key,str_arg):\n",
    "    if key == 'dob':\n",
    "        return get_regex_match('Date of Birth:(.*)] ', str_arg)\n",
    "    elif key == 'a_date':\n",
    "        return get_regex_match('Admission Date:(.*)] ', str_arg)\n",
    "    elif key == 'd_date':\n",
    "        return get_regex_match('Discharge Date:(.*)]\\n', str_arg)\n",
    "    elif key == 'sex':\n",
    "        return get_regex_match('Sex:(.*)\\n', str_arg)\n",
    "    elif key == 'service':\n",
    "        return get_regex_match('Service:(.*)\\n', str_arg)\n",
    "    elif key == 'allergy':\n",
    "        return get_regex_match('Allergies:(.*)\\n(.*)\\n', str_arg)\n",
    "    elif key == 'attdng':\n",
    "        return get_regex_match('Attending:(.*)]\\n', str_arg)\n",
    "    else:\n",
    "        return \"I Don't know\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:46.788047Z",
     "start_time": "2018-06-11T22:22:46.771119Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This method extracts topic from sentence'''\n",
    "def extract_topic(str_arg, num_topics = 1, num_top_words = 3):\n",
    "    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')\n",
    "    dtm = vectorizer.fit_transform(str_arg.split())\n",
    "    vocab = np.array(vectorizer.get_feature_names())\n",
    "    \n",
    "    #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction\n",
    "    clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')\n",
    "    clf.fit_transform(dtm)\n",
    "    \n",
    "    topic_words = []\n",
    "    for topic in clf.components_:\n",
    "        word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list\n",
    "        topic_words.append([vocab[i] for i in word_idx])\n",
    "    return topic_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:47.747483Z",
     "start_time": "2018-06-11T22:22:47.740500Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This method extracts topics in a question'''\n",
    "def extract_Q_topic(str_arg):\n",
    "    try:\n",
    "        return extract_topic(str_arg)\n",
    "    except:\n",
    "        return None\n",
    "    ## Future Scope fix later for more comprehensive results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:48.590228Z",
     "start_time": "2018-06-11T22:22:48.578259Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_extract_map(key_wrd):\n",
    "    ## A Stemmed mapping for simple extractions\n",
    "    extract_map = {'birth':'dob', 'dob':'dob',\n",
    "              'admiss':'a_date', 'discharg':'d_date',\n",
    "              'sex':'sex', 'gender':'sex', 'servic':'service',\n",
    "              'allergi':'allergy', 'attend':'attdng'}\n",
    "    if key_wrd in extract_map.keys():\n",
    "        return extract_map[key_wrd]\n",
    "    else:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:49.521736Z",
     "start_time": "2018-06-11T22:22:49.504781Z"
    }
   },
   "outputs": [],
   "source": [
    "'''Method that generates the answer for text extraction questions'''\n",
    "def get_extracted_answer(topic_str, text):\n",
    "    port = PorterStemmer()\n",
    "    for i in range(0, len(topic_str)):\n",
    "        rel_wrd = topic_str[i]\n",
    "        for wrd in rel_wrd:\n",
    "            key = get_extract_map(port.stem(wrd))\n",
    "            if key is not None:\n",
    "                return extract(key, text)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:50.506103Z",
     "start_time": "2018-06-11T22:22:50.494136Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This method extracts topics of each sentence and returns a list'''\n",
    "def get_topic_mapping(doc_string):\n",
    "    #One entry per sentence in list\n",
    "    doc_str = process_text(doc_string)\n",
    "    doc_str = get_processed_sentences(doc_str)\n",
    "    \n",
    "    res = defaultdict(list)\n",
    "    for i in range (0, len(doc_str)):\n",
    "        snd_str = doc_str[i].lower()\n",
    "        #print(\"Sending ----------------------------\",snd_str,\"==========\",len(snd_str))\n",
    "        tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)\n",
    "        for top in tmp_topic:\n",
    "            for wrd in top:\n",
    "                res[wrd].append(doc_str[i])\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:51.287014Z",
     "start_time": "2018-06-11T22:22:51.280036Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_direct_answer(topic_str, topic_map):\n",
    "    ## Maybe apply lemmatizer here\n",
    "    for i in range(0, len(topic_str)):\n",
    "        rel_wrd = topic_str[i]\n",
    "        for wrd in rel_wrd:\n",
    "            if wrd in topic_map.keys():\n",
    "                return topic_map[wrd]\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:52.353164Z",
     "start_time": "2018-06-11T22:22:52.343190Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_answer(topic, topic_map, embedding_short, all_topics, data_set, vocab_map, pca, wdw_size=5):\n",
    "    ## Get most similar topics\n",
    "    tpc_embedding = get_embedding(topic, data_set, vocab_map, wdw_size)\n",
    "    tpc_embedding = pca.transform([tpc_embedding])\n",
    "    sim_topics = get_most_similar_topics(tpc_embedding[0], embedding_short, all_topics, num_wrd = len(all_topics))\n",
    "    for topic in sim_topics:\n",
    "        if topic in topic_map.keys():\n",
    "            return topic_map[topic]\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:53.157013Z",
     "start_time": "2018-06-11T22:22:53.150059Z"
    }
   },
   "outputs": [],
   "source": [
    "'''This function checks if the user input text is an instruction allowed in chatbot or not'''\n",
    "def is_instruction_option(str_arg):\n",
    "    if str_arg == \"exit\" or str_arg == \"summary\" or str_arg == \"reveal\":\n",
    "        return True\n",
    "    else:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-11T22:22:53.953950Z",
     "start_time": "2018-06-11T22:22:53.941983Z"
    }
   },
   "outputs": [],
   "source": [
    "def print_bot():\n",
    "\tprint(r\"          _ _ _\")\n",
    "\tprint(r\"         | o o |\")\n",
    "\tprint(r\"        \\|  =  |/\")\n",
    "\tprint(r\"         -------\")\n",
    "\tprint(r\"         |||||||\")\n",
    "\tprint(r\"         //   \\\\\")\n",
    "\t\n",
    "def print_caption():\n",
    "\tprint(r\"\t||\\\\   ||  ||       ||= =||\")\n",
    "\tprint(r\"\t|| \\\\  ||  ||       ||= =||\")\n",
    "\tprint(r\"\t||  \\\\ ||  ||       ||\")\n",
    "\tprint(r\"\t||   \\\\||  ||_ _ _  ||\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2018-06-11T22:24:04.067Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data ... \n",
      "\n",
      "Getting Vocabulary ...\n",
      "    0 records processed\n",
      "Creating context ...\n",
      "Learning topics ...\n",
      "Processed  1  records\n",
      "Processed  2  records\n",
      "Processed  3  records\n",
      "Processed  4  records\n",
      "Processed  5  records\n",
      "Processed  6  records\n",
      "Processed  7  records\n",
      "Processed  8  records\n",
      "Processed  9  records\n",
      "Processed  10  records\n",
      "Processed  11  records\n",
      "Processed  12  records\n",
      "Processed  13  records\n",
      "Processed  14  records\n",
      "Processed  15  records\n",
      "Processed  16  records\n",
      "Processed  17  records\n",
      "Processed  18  records\n",
      "Processed  19  records\n",
      "Processed  20  records\n",
      "Processed  21  records\n",
      "Processed  22  records\n",
      "Processed  23  records\n",
      "Processed  24  records\n",
      "Processed  25  records\n",
      "Processed  26  records\n",
      "Processed  27  records\n",
      "Processed  28  records\n",
      "Processed  29  records\n",
      "Processed  30  records\n",
      "Processed  31  records\n",
      "Processed  32  records\n",
      "Processed  33  records\n",
      "Processed  34  records\n",
      "Processed  35  records\n",
      "Processed  36  records\n",
      "Processed  37  records\n",
      "Processed  38  records\n",
      "Processed  39  records\n",
      "Processed  40  records\n",
      "Processed  41  records\n",
      "Processed  42  records\n",
      "Processed  43  records\n",
      "Processed  44  records\n",
      "Processed  45  records\n",
      "Processed  46  records\n",
      "Processed  47  records\n",
      "Processed  48  records\n",
      "Processed  49  records\n",
      "Processed  50  records\n",
      "Getting Embeddings\n",
      "\t||\\\\   ||  ||       ||= =||\n",
      "\t|| \\\\  ||  ||       ||= =||\n",
      "\t||  \\\\ ||  ||       ||\n",
      "\t||   \\\\||  ||_ _ _  ||\n",
      "          _ _ _\n",
      "         | o o |\n",
      "        \\|  =  |/\n",
      "         -------\n",
      "         |||||||\n",
      "         //   \\\\\n",
      "Bot:> I am online!\n",
      "Bot:> Type \"exit\" to switch to end a patient's session\n",
      "Bot:> Type \"summary\" to view patient's discharge summary\n",
      "Bot:> What is your Patient Id [0 to 49?]5\n",
      "Bot:> Reading Discharge Summary for Patient Id:  5\n",
      "Bot:> How can I help ?\n",
      "Person:>What is my date of birth?\n",
      "Bot:>  Date of Birth:  [**2109-10-8**]\n",
      "Bot:> How can I help ?\n",
      "Person:>When was I discharged?\n",
      "Bot:>  Discharge Date:   [**2172-3-8**]\n",
      "Bot:> How can I help ?\n",
      "Person:>What is my gender?\n",
      "Bot:>  Sex:   F\n",
      "Bot:> How can I help ?\n",
      "Person:>What are the services I had?\n",
      "Bot:>  Service: NEUROSURGERY\n",
      "Bot:> How can I help ?\n",
      "Person:>Am I married?\n",
      "Bot:>  ['Social History She is married']\n",
      "Bot:> How can I help ?\n",
      "Person:>How was my MRI?\n",
      "Bot:>  ['Pertinent Results MRI  Right middle cranial fossa mass likely represents a meningioma and is stable since MRI of', 'The previously seen midline nasopharyngeal mass has decreased in size since MRI of']\n",
      "Bot:> How can I help ?\n",
      "Person:>How can I make an appointment?\n",
      "Bot:>  ['Make sure to take your steroid medication with meals or a glass of milk']\n",
      "Bot:> How can I help ?\n",
      "Person:>Do I have sinus?\n",
      "Bot:>  ['She was found to hve a right cavernous sinus and nasopharyngeal mass', 'A gadoliniumenhanced head MRI performed at Hospital  on  showed a bright mass involving the cavernous sinus']\n",
      "Bot:> How can I help ?\n",
      "Person:>How should I take my medication?\n",
      "Bot:>  ['If you are being sent home on steroid medication make sure you are taking a medication to protect your stomach Prilosec Protonix or Pepcid as these medications can cause stomach irritation', 'Pain or headache that is continually increasing or not relieved by pain medication']\n"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    print(\"Loading data ...\",\"\\n\")\n",
    "    df_text = load_data()\n",
    "    \n",
    "    print(\"Getting Vocabulary ...\")\n",
    "    data_set, vocabulary, _vocab = get_vocab_wrd_map(df_text)\n",
    "    \n",
    "    print(\"Creating context ...\")\n",
    "    vocab = get_common_vocab(1000, _vocab)\n",
    "    vocabulary_map, vocab_map = get_vocab_map(vocabulary, vocab)\n",
    "    \n",
    "    print(\"Learning topics ...\")\n",
    "    all_topics = extract_corpus_topics(df_text)\n",
    "    \n",
    "    print(\"Getting Embeddings\")\n",
    "    embeddings = get_embedding_all(all_topics, data_set, vocab_map, 5)\n",
    "    \n",
    "    pca = PCA(n_components=10)\n",
    "    embedding_short = pca.fit_transform(embeddings)\n",
    "    \n",
    "    print_caption()\n",
    "    print_bot()\n",
    "    print(\"Bot:> I am online!\")\n",
    "    print(\"Bot:> Type \\\"exit\\\" to switch to end a patient's session\")\n",
    "    print(\"Bot:> Type \\\"summary\\\" to view patient's discharge summary\")\n",
    "    while(True):\n",
    "        while(True):\n",
    "            try:\n",
    "                pid = int(input(\"Bot:> What is your Patient Id [0 to \"+str(df_text.shape[0]-1)+\"?]\"))\n",
    "            except:\n",
    "                continue\n",
    "            if pid < 0 or pid > df_text.shape[0]-1:\n",
    "                print(\"Bot:> Patient Id out or range!\")\n",
    "                continue\n",
    "            else:\n",
    "                print(\"Bot:> Reading Discharge Summary for Patient Id: \",pid)\n",
    "                break\n",
    "\n",
    "        personal_topics = extract_topics_all(df_text[pid])\n",
    "        topic_mapping = get_topic_mapping(df_text[pid])\n",
    "        \n",
    "        ques = \"random starter\"\n",
    "        while(ques != \"exit\"):\n",
    "            ## Read Question\n",
    "            ques = input(\"Bot:> How can I help ?\\nPerson:>\")\n",
    "            \n",
    "            ## Check if it is an instructional question\n",
    "            if is_instruction_option(ques):\n",
    "                if ques == \"summary\":\n",
    "                    print(\"Bot:> ================= Discharge Summary for Patient Id \",pid,\"\\n\")\n",
    "                    print(df_text[pid])\n",
    "                elif ques == \"reveal\":\n",
    "                    print(topic_mapping, topic_mapping.keys())\n",
    "                continue\n",
    "                \n",
    "            ## Extract Question topic\n",
    "            topic_q = extract_Q_topic(ques)\n",
    "            if topic_q is None:\n",
    "                print(\"Bot:> I am a specialized NLP bot, please as a more specific question for me!\")\n",
    "                continue\n",
    "            ans = get_extracted_answer(topic_q, df_text[pid])\n",
    "            if ans is not None:\n",
    "                print(\"Bot:> \",ans)\n",
    "            else:\n",
    "                ans = get_direct_answer(topic_q, topic_mapping)\n",
    "                if ans is not None:\n",
    "                    print(\"Bot:> \",ans)\n",
    "                else:\n",
    "                    ans = get_answer(topic_q, topic_mapping, embedding_short, all_topics, data_set, vocab_map, pca, 5)\n",
    "                    if ans is not None:\n",
    "                        print(\"Bot:> \",ans)\n",
    "                    else:\n",
    "                        print(\"Bot:> Sorry but, I have no information on this topic!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}