--- /dev/null
+++ b/MED277_bot.ipynb
@@ -0,0 +1,883 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:28.421703Z",
+     "start_time": "2018-06-11T22:22:26.327228Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "from sklearn.externals import joblib\n",
+    "import re\n",
+    "from nltk.stem.snowball import SnowballStemmer\n",
+    "from collections import defaultdict\n",
+    "import operator\n",
+    "import numpy as np\n",
+    "import sklearn.feature_extraction.text as text\n",
+    "from sklearn import decomposition\n",
+    "from nltk.stem import PorterStemmer, WordNetLemmatizer\n",
+    "from sklearn.decomposition import PCA\n",
+    "from numpy.linalg import norm"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:30.024471Z",
+     "start_time": "2018-06-11T22:22:30.000549Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This function loads discharge summary data from the MIMIC-III NOTEEVENTS.csv.gz file.\n",
+    "Update base_path to point to the directory containing this zipped file.'''\n",
+    "def load_data():\n",
+    "    ## Initializing data paths\n",
+    "    base_path = r'D:\\ORGANIZATION\\UCSD_Life\\Work\\4. Quarter-3\\Subjects\\MED 277\\Project\\DATA\\\\'\n",
+    "    data_file = base_path+\"NOTEEVENTS.csv.gz\"\n",
+    "    \n",
+    "    ## Loading data frames from CSV file\n",
+    "    df = pd.read_csv(data_file, compression='gzip')\n",
+    "    \n",
+    "    ## Uncomment this to slice the size of dataset\n",
+    "    #df = df[:10000]\n",
+    "    \n",
+    "    ## Uncomment this to save processed data to disk\n",
+    "    #joblib.dump(df,base_path+'data10.pkl')\n",
+    "    ## Uncomment this to load data frames back from the PKL file\n",
+    "    #df1 = joblib.load(base_path+'data10.pkl')\n",
+    "    \n",
+    "    ## Filtering dataframe for \"Discharge summaries\" and \"TEXT\"\n",
+    "    df = df.loc[df['CATEGORY'] == 'Discharge summary'] #Extracting only discharge summaries\n",
+    "    df_text = df['TEXT']\n",
+    "    ## Reset to a 0..N-1 index so positional lookups like df_text[i] still work after filtering\n",
+    "    df_text = df_text.reset_index(drop=True)\n",
+    "    return df_text"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## EXTRACT ALL THE TOPICS"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:32.608559Z",
+     "start_time": "2018-06-11T22:22:32.603573Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''Method that processes the entire document string, keeping sentence boundaries (periods)'''\n",
+    "def process_text(txt):\n",
+    "    txt1 = re.sub('[\\n]',\" \",txt)\n",
+    "    txt1 = re.sub('[^A-Za-z \\.]+', '', txt1)\n",
+    "    \n",
+    "    return txt1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:33.596917Z",
+     "start_time": "2018-06-11T22:22:33.586970Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''Method that processes the document string into stemmed word tokens, ignoring sentence boundaries'''\n",
+    "def process(txt):\n",
+    "    txt1 = re.sub('[\\n]',\" \",txt)\n",
+    "    txt1 = re.sub('[^A-Za-z ]+', '', txt1)\n",
+    "    \n",
+    "    _wrds = txt1.split()\n",
+    "    stemmer = SnowballStemmer(\"english\") ## May use Porter stemmer instead\n",
+    "    wrds = [stemmer.stem(wrd) for wrd in _wrds]\n",
+    "    return wrds"
+   ]
+  },
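+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick illustration of `process` on a toy string (illustrative only, not part of the original pipeline; assumes the NLTK Snowball stemmer data is available):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Illustrative only: newlines and punctuation are stripped, tokens are Snowball-stemmed\n",
+    "print(process(\"Patients were admitted.\\nAllergies: none\"))  ## ['patient', 'were', 'admit', 'allergi', 'none']"
+   ]
+  },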
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:34.412734Z",
+     "start_time": "2018-06-11T22:22:34.402790Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''Method that takes a raw string and returns a processed list of sentences'''\n",
+    "def get_processed_sentences(snt_txt):\n",
+    "    snt_list = []\n",
+    "    for line in snt_txt.split('.'):\n",
+    "        line = line.strip()\n",
+    "        if len(line.split()) >= 5:  ## keep only sentences with at least 5 words\n",
+    "            snt_list.append(line)\n",
+    "    return snt_list"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:35.078078Z",
+     "start_time": "2018-06-11T22:22:35.051176Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This method extracts topics from a sentence'''\n",
+    "def extract_topic(str_arg, num_topics = 1, num_top_words = 3):\n",
+    "    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')\n",
+    "    try:\n",
+    "        dtm = vectorizer.fit_transform(str_arg.split())\n",
+    "        vocab = np.array(vectorizer.get_feature_names())\n",
+    "        \n",
+    "        #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction\n",
+    "        clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')\n",
+    "        clf.fit_transform(dtm)\n",
+    "\n",
+    "        topic_words = []\n",
+    "        for topic in clf.components_:\n",
+    "            word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list\n",
+    "            topic_words.append([vocab[i] for i in word_idx])\n",
+    "        return topic_words\n",
+    "    except Exception:  ## e.g. ValueError when no valid tokens remain after stop-word removal\n",
+    "        return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:35.804137Z",
+     "start_time": "2018-06-11T22:22:35.794166Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This method extracts the topics of each sentence and returns them as a list'''\n",
+    "def extract_topics_all(doc_string):\n",
+    "    #One entry per sentence in list\n",
+    "    doc_str = process_text(doc_string)\n",
+    "    doc_str = get_processed_sentences(doc_str)\n",
+    "    \n",
+    "    res = []\n",
+    "    for i in range(0, len(doc_str)):\n",
+    "        snd_str = doc_str[i].lower()\n",
+    "        #print(\"Sending ----------------------------\",snd_str,\"==========\",len(snd_str))\n",
+    "        tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)\n",
+    "        if tmp_topic is None:  ## skip sentences where topic extraction failed\n",
+    "            continue\n",
+    "        for top in tmp_topic:\n",
+    "            for wrd in top:\n",
+    "                res.append(wrd)\n",
+    "    return res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:36.470386Z",
+     "start_time": "2018-06-11T22:22:36.462381Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This function takes a dataframe and returns all the topics in the entire corpus'''\n",
+    "def extract_corpus_topics(arg_df):\n",
+    "    all_topics = set()\n",
+    "    cnt = 1\n",
+    "    for txt in arg_df:\n",
+    "        all_topics = all_topics.union(extract_topics_all(txt))\n",
+    "        print(\"Processed \",cnt,\" records\")\n",
+    "        cnt += 1\n",
+    "    all_topics = list(all_topics)\n",
+    "    return all_topics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## GET A VECTORIZED REPRESENTATION OF ALL THE TOPICS"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:38.161868Z",
+     "start_time": "2018-06-11T22:22:38.140924Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''data_set = words list per document.\n",
+    "   vocabulary = list of all the words present\n",
+    "   _vocab = dict of word counts for words in vocabulary'''\n",
+    "def get_vocab_wrd_map(df_text):\n",
+    "    data_set = []\n",
+    "    vocabulary = []\n",
+    "    _vocab = defaultdict(int)\n",
+    "    for i in range(0,df_text.size):\n",
+    "        txt = process(df_text[i])\n",
+    "        data_set.append(txt)\n",
+    "\n",
+    "        for wrd in txt:\n",
+    "            _vocab[wrd] += 1\n",
+    "\n",
+    "        vocabulary = vocabulary + txt\n",
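+    "        ## Deduplicate after each document so the running vocabulary list stays small\n",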
" vocabulary = list(set(vocabulary))\n", + "\n", + " if(i%100 == 0):\n", + " print(\"%5d records processed\"%(i))\n", + " return data_set, vocabulary, _vocab" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:39.105476Z", + "start_time": "2018-06-11T22:22:39.099498Z" + } + }, + "outputs": [], + "source": [ + "'''vocab = return sorted list of most common words in vocabulary'''\n", + "def get_common_vocab(num_arg, vocab):\n", + " vocab = sorted(vocab.items(), key=operator.itemgetter(1), reverse=True)\n", + " vocab = vocab[:num_arg]\n", + " return vocab" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:39.851482Z", + "start_time": "2018-06-11T22:22:39.839514Z" + } + }, + "outputs": [], + "source": [ + "'''Convert vocabulary and most common words to map for faster access'''\n", + "def get_vocab_map(vocabulary, vocab):\n", + " vocab_map = {}\n", + " for i in range(0,len(vocab)):\n", + " vocab_map[vocab[i][0]] = i \n", + " \n", + " vocabulary_map = {}\n", + " for i in range(0,len(vocabulary)):\n", + " vocabulary_map[vocabulary[i]] = i\n", + " \n", + " return vocabulary_map, vocab_map" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:40.626409Z", + "start_time": "2018-06-11T22:22:40.609455Z" + } + }, + "outputs": [], + "source": [ + "'''This function returns n-gram context embedding for each word'''\n", + "def get_embedding(word, data_set, vocab_map, wdw_size):\n", + " embedding = [0]*len(vocab_map)\n", + " for docs in data_set:\n", + " for i in range(wdw_size, len(docs)-wdw_size):\n", + " if docs[i] == word:\n", + " for j in range(i-wdw_size, i-1):\n", + " if docs[j] in vocab_map:\n", + " embedding[vocab_map[docs[j]]] += 1\n", + " for j in range(i+1, i+wdw_size):\n", + " if docs[j] in vocab_map:\n", + " embedding[vocab_map[docs[j]]] += 1\n", + " total_words = sum(embedding)\n", + " if total_words != 0:\n", + " embedding[:] = [e/total_words for e in embedding]\n", + " return embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:41.413308Z", + "start_time": "2018-06-11T22:22:41.405327Z" + } + }, + "outputs": [], + "source": [ + "'''This is a helper function that returns n-gram embedding for all the topics in the corpus'''\n", + "def get_embedding_all(all_topics, data_set, vocab_map, wdw_size):\n", + " embeddings = []\n", + " for i in range(0, len(all_topics)):\n", + " embeddings.append(get_embedding(all_topics[i], data_set, vocab_map, wdw_size))\n", + " return embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get similarity function" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:42.876508Z", + "start_time": "2018-06-11T22:22:42.868529Z" + } + }, + "outputs": [], + "source": [ + "def cos_matrix_multiplication(matrix, vector):\n", + " \"\"\"\n", + " Calculating pairwise cosine distance using matrix vector multiplication.\n", + " \"\"\"\n", + " dotted = matrix.dot(vector)\n", + " matrix_norms = np.linalg.norm(matrix, axis=1)\n", + " vector_norm = np.linalg.norm(vector)\n", + " matrix_vector_norms = np.multiply(matrix_norms, vector_norm)\n", + " neighbors = np.divide(dotted, matrix_vector_norms)\n", + " return neighbors" + ] + }, + { + "cell_type": "code", + 
"execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:43.710277Z", + "start_time": "2018-06-11T22:22:43.695318Z" + } + }, + "outputs": [], + "source": [ + "'''This function generates most similar topic to a given embedding'''\n", + "def get_most_similar_topics(embd, embeddings, all_topics, num_wrd=10):\n", + " sim_top = []\n", + " cos_sim = cos_matrix_multiplication(np.array(embeddings), embd)\n", + " #closest_match = cos_sim.argsort()[-num_wrd:][::-1] ## This sorts all matches in order\n", + " \n", + " ## This just takes 80% and above similar matches\n", + " idx = list(np.where(cos_sim > 0.9)[0])\n", + " val = list(cos_sim[np.where(cos_sim > 0.9)])\n", + " closest_match, list2 = (list(t) for t in zip(*sorted(zip(idx, val), reverse=True)))\n", + " closest_match = np.array(closest_match)\n", + " \n", + " for i in range(0, closest_match.shape[0]):\n", + " sim_top.append(all_topics[closest_match[i]])\n", + " return sim_top" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Topic Modelling" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:44.978887Z", + "start_time": "2018-06-11T22:22:44.972904Z" + } + }, + "outputs": [], + "source": [ + "'''This function extracts matches for a regular expression in the text'''\n", + "def get_regex_match(regex, str_arg):\n", + " srch = re.search(regex,str_arg)\n", + " if srch is not None:\n", + " return srch.group(0).strip()\n", + " else:\n", + " return \"Not found\"" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:45.895464Z", + "start_time": "2018-06-11T22:22:45.878481Z" + } + }, + "outputs": [], + "source": [ + "'''This is a helper function that helps extracting answer to extraction type questions'''\n", + "def extract(key,str_arg):\n", + " if key == 'dob':\n", + " return get_regex_match('Date of Birth:(.*)] ', str_arg)\n", + " elif key == 'a_date':\n", + " return get_regex_match('Admission Date:(.*)] ', str_arg)\n", + " elif key == 'd_date':\n", + " return get_regex_match('Discharge Date:(.*)]\\n', str_arg)\n", + " elif key == 'sex':\n", + " return get_regex_match('Sex:(.*)\\n', str_arg)\n", + " elif key == 'service':\n", + " return get_regex_match('Service:(.*)\\n', str_arg)\n", + " elif key == 'allergy':\n", + " return get_regex_match('Allergies:(.*)\\n(.*)\\n', str_arg)\n", + " elif key == 'attdng':\n", + " return get_regex_match('Attending:(.*)]\\n', str_arg)\n", + " else:\n", + " return \"I Don't know\"" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:46.788047Z", + "start_time": "2018-06-11T22:22:46.771119Z" + } + }, + "outputs": [], + "source": [ + "'''This method extracts topic from sentence'''\n", + "def extract_topic(str_arg, num_topics = 1, num_top_words = 3):\n", + " vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')\n", + " dtm = vectorizer.fit_transform(str_arg.split())\n", + " vocab = np.array(vectorizer.get_feature_names())\n", + " \n", + " #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction\n", + " clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')\n", + " clf.fit_transform(dtm)\n", + " \n", + " topic_words = []\n", + " for topic in clf.components_:\n", + " word_idx = np.argsort(topic)[::-1][0:num_top_words] 
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:46.788047Z",
+     "start_time": "2018-06-11T22:22:46.771119Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This method extracts topics from a sentence. Redefined without the\n",
+    "try/except so callers can handle failures themselves (see extract_Q_topic below)'''\n",
+    "def extract_topic(str_arg, num_topics = 1, num_top_words = 3):\n",
+    "    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')\n",
+    "    dtm = vectorizer.fit_transform(str_arg.split())\n",
+    "    vocab = np.array(vectorizer.get_feature_names())\n",
+    "    \n",
+    "    #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction\n",
+    "    clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')\n",
+    "    clf.fit_transform(dtm)\n",
+    "    \n",
+    "    topic_words = []\n",
+    "    for topic in clf.components_:\n",
+    "        word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list\n",
+    "        topic_words.append([vocab[i] for i in word_idx])\n",
+    "    return topic_words"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:47.747483Z",
+     "start_time": "2018-06-11T22:22:47.740500Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This method extracts the topics in a question'''\n",
+    "def extract_Q_topic(str_arg):\n",
+    "    ## Future scope: fix later for more comprehensive results\n",
+    "    try:\n",
+    "        return extract_topic(str_arg)\n",
+    "    except Exception:\n",
+    "        return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:48.590228Z",
+     "start_time": "2018-06-11T22:22:48.578259Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def get_extract_map(key_wrd):\n",
+    "    ## A stemmed mapping for simple extractions\n",
+    "    extract_map = {'birth':'dob', 'dob':'dob',\n",
+    "                   'admiss':'a_date', 'discharg':'d_date',\n",
+    "                   'sex':'sex', 'gender':'sex', 'servic':'service',\n",
+    "                   'allergi':'allergy', 'attend':'attdng'}\n",
+    "    if key_wrd in extract_map.keys():\n",
+    "        return extract_map[key_wrd]\n",
+    "    else:\n",
+    "        return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:49.521736Z",
+     "start_time": "2018-06-11T22:22:49.504781Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''Method that generates the answer for text extraction questions'''\n",
+    "def get_extracted_answer(topic_str, text):\n",
+    "    port = PorterStemmer()\n",
+    "    for i in range(0, len(topic_str)):\n",
+    "        rel_wrd = topic_str[i]\n",
+    "        for wrd in rel_wrd:\n",
+    "            key = get_extract_map(port.stem(wrd))\n",
+    "            if key is not None:\n",
+    "                return extract(key, text)\n",
+    "    return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:50.506103Z",
+     "start_time": "2018-06-11T22:22:50.494136Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "'''This method extracts the topics of each sentence and returns a topic -> sentences map'''\n",
+    "def get_topic_mapping(doc_string):\n",
+    "    #One entry per sentence in list\n",
+    "    doc_str = process_text(doc_string)\n",
+    "    doc_str = get_processed_sentences(doc_str)\n",
+    "    \n",
+    "    res = defaultdict(list)\n",
+    "    for i in range(0, len(doc_str)):\n",
+    "        snd_str = doc_str[i].lower()\n",
+    "        #print(\"Sending ----------------------------\",snd_str,\"==========\",len(snd_str))\n",
+    "        try:\n",
+    "            tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)\n",
+    "        except Exception:  ## skip sentences with no usable vocabulary\n",
+    "            continue\n",
+    "        for top in tmp_topic:\n",
+    "            for wrd in top:\n",
+    "                res[wrd].append(doc_str[i])\n",
+    "    return res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:51.287014Z",
+     "start_time": "2018-06-11T22:22:51.280036Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def get_direct_answer(topic_str, topic_map):\n",
+    "    ## Maybe apply a lemmatizer here\n",
+    "    for i in range(0, len(topic_str)):\n",
+    "        rel_wrd = topic_str[i]\n",
+    "        for wrd in rel_wrd:\n",
+    "            if wrd in topic_map.keys():\n",
+    "                return topic_map[wrd]\n",
+    "    return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-06-11T22:22:52.353164Z",
+     "start_time": "2018-06-11T22:22:52.343190Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def get_answer(topic, topic_map, embedding_short, all_topics, data_set, vocab_map, pca, wdw_size=5):\n",
+    "    ## Get most 
similar topics\n", + " tpc_embedding = get_embedding(topic, data_set, vocab_map, wdw_size)\n", + " tpc_embedding = pca.transform([tpc_embedding])\n", + " sim_topics = get_most_similar_topics(tpc_embedding[0], embedding_short, all_topics, num_wrd = len(all_topics))\n", + " for topic in sim_topics:\n", + " if topic in topic_map.keys():\n", + " return topic_map[topic]\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:53.157013Z", + "start_time": "2018-06-11T22:22:53.150059Z" + } + }, + "outputs": [], + "source": [ + "'''This function checks if the user input text is an instruction allowed in chatbot or not'''\n", + "def is_instruction_option(str_arg):\n", + " if str_arg == \"exit\" or str_arg == \"summary\" or str_arg == \"reveal\":\n", + " return True\n", + " else:\n", + " return False" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2018-06-11T22:22:53.953950Z", + "start_time": "2018-06-11T22:22:53.941983Z" + } + }, + "outputs": [], + "source": [ + "def print_bot():\n", + "\tprint(r\" _ _ _\")\n", + "\tprint(r\" | o o |\")\n", + "\tprint(r\" \\| = |/\")\n", + "\tprint(r\" -------\")\n", + "\tprint(r\" |||||||\")\n", + "\tprint(r\" // \\\\\")\n", + "\t\n", + "def print_caption():\n", + "\tprint(r\"\t||\\\\ || || ||= =||\")\n", + "\tprint(r\"\t|| \\\\ || || ||= =||\")\n", + "\tprint(r\"\t|| \\\\ || || ||\")\n", + "\tprint(r\"\t|| \\\\|| ||_ _ _ ||\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "start_time": "2018-06-11T22:24:04.067Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading data ... 
\n", + "\n", + "Getting Vocabulary ...\n", + " 0 records processed\n", + "Creating context ...\n", + "Learning topics ...\n", + "Processed 1 records\n", + "Processed 2 records\n", + "Processed 3 records\n", + "Processed 4 records\n", + "Processed 5 records\n", + "Processed 6 records\n", + "Processed 7 records\n", + "Processed 8 records\n", + "Processed 9 records\n", + "Processed 10 records\n", + "Processed 11 records\n", + "Processed 12 records\n", + "Processed 13 records\n", + "Processed 14 records\n", + "Processed 15 records\n", + "Processed 16 records\n", + "Processed 17 records\n", + "Processed 18 records\n", + "Processed 19 records\n", + "Processed 20 records\n", + "Processed 21 records\n", + "Processed 22 records\n", + "Processed 23 records\n", + "Processed 24 records\n", + "Processed 25 records\n", + "Processed 26 records\n", + "Processed 27 records\n", + "Processed 28 records\n", + "Processed 29 records\n", + "Processed 30 records\n", + "Processed 31 records\n", + "Processed 32 records\n", + "Processed 33 records\n", + "Processed 34 records\n", + "Processed 35 records\n", + "Processed 36 records\n", + "Processed 37 records\n", + "Processed 38 records\n", + "Processed 39 records\n", + "Processed 40 records\n", + "Processed 41 records\n", + "Processed 42 records\n", + "Processed 43 records\n", + "Processed 44 records\n", + "Processed 45 records\n", + "Processed 46 records\n", + "Processed 47 records\n", + "Processed 48 records\n", + "Processed 49 records\n", + "Processed 50 records\n", + "Getting Embeddings\n", + "\t||\\\\ || || ||= =||\n", + "\t|| \\\\ || || ||= =||\n", + "\t|| \\\\ || || ||\n", + "\t|| \\\\|| ||_ _ _ ||\n", + " _ _ _\n", + " | o o |\n", + " \\| = |/\n", + " -------\n", + " |||||||\n", + " // \\\\\n", + "Bot:> I am online!\n", + "Bot:> Type \"exit\" to switch to end a patient's session\n", + "Bot:> Type \"summary\" to view patient's discharge summary\n", + "Bot:> What is your Patient Id [0 to 49?]5\n", + "Bot:> Reading Discharge Summary for Patient Id: 5\n", + "Bot:> How can I help ?\n", + "Person:>What is my date of birth?\n", + "Bot:> Date of Birth: [**2109-10-8**]\n", + "Bot:> How can I help ?\n", + "Person:>When was I discharged?\n", + "Bot:> Discharge Date: [**2172-3-8**]\n", + "Bot:> How can I help ?\n", + "Person:>What is my gender?\n", + "Bot:> Sex: F\n", + "Bot:> How can I help ?\n", + "Person:>What are the services I had?\n", + "Bot:> Service: NEUROSURGERY\n", + "Bot:> How can I help ?\n", + "Person:>Am I married?\n", + "Bot:> ['Social History She is married']\n", + "Bot:> How can I help ?\n", + "Person:>How was my MRI?\n", + "Bot:> ['Pertinent Results MRI Right middle cranial fossa mass likely represents a meningioma and is stable since MRI of', 'The previously seen midline nasopharyngeal mass has decreased in size since MRI of']\n", + "Bot:> How can I help ?\n", + "Person:>How can I make an appointment?\n", + "Bot:> ['Make sure to take your steroid medication with meals or a glass of milk']\n", + "Bot:> How can I help ?\n", + "Person:>Do I have sinus?\n", + "Bot:> ['She was found to hve a right cavernous sinus and nasopharyngeal mass', 'A gadoliniumenhanced head MRI performed at Hospital on showed a bright mass involving the cavernous sinus']\n", + "Bot:> How can I help ?\n", + "Person:>How should I take my medication?\n", + "Bot:> ['If you are being sent home on steroid medication make sure you are taking a medication to protect your stomach Prilosec Protonix or Pepcid as these medications can cause stomach irritation', 'Pain or headache that is 
continually increasing or not relieved by pain medication']\n"
+     ]
+    }
+   ],
+   "source": [
+    "if __name__ == \"__main__\":\n",
+    "    print(\"Loading data ...\",\"\\n\")\n",
+    "    df_text = load_data()\n",
+    "    \n",
+    "    print(\"Getting Vocabulary ...\")\n",
+    "    data_set, vocabulary, _vocab = get_vocab_wrd_map(df_text)\n",
+    "    \n",
+    "    print(\"Creating context ...\")\n",
+    "    vocab = get_common_vocab(1000, _vocab)\n",
+    "    vocabulary_map, vocab_map = get_vocab_map(vocabulary, vocab)\n",
+    "    \n",
+    "    print(\"Learning topics ...\")\n",
+    "    all_topics = extract_corpus_topics(df_text)\n",
+    "    \n",
+    "    print(\"Getting Embeddings\")\n",
+    "    embeddings = get_embedding_all(all_topics, data_set, vocab_map, 5)\n",
+    "    \n",
+    "    pca = PCA(n_components=10)\n",
+    "    embedding_short = pca.fit_transform(embeddings)\n",
+    "    \n",
+    "    print_caption()\n",
+    "    print_bot()\n",
+    "    print(\"Bot:> I am online!\")\n",
+    "    print(\"Bot:> Type \\\"exit\\\" to end a patient's session\")\n",
+    "    print(\"Bot:> Type \\\"summary\\\" to view the patient's discharge summary\")\n",
+    "    while(True):\n",
+    "        while(True):\n",
+    "            try:\n",
+    "                pid = int(input(\"Bot:> What is your Patient Id [0 to \"+str(df_text.shape[0]-1)+\"]? \"))\n",
+    "            except ValueError:\n",
+    "                continue\n",
+    "            if pid < 0 or pid > df_text.shape[0]-1:\n",
+    "                print(\"Bot:> Patient Id out of range!\")\n",
+    "                continue\n",
+    "            else:\n",
+    "                print(\"Bot:> Reading Discharge Summary for Patient Id: \",pid)\n",
+    "                break\n",
+    "\n",
+    "        personal_topics = extract_topics_all(df_text[pid])\n",
+    "        topic_mapping = get_topic_mapping(df_text[pid])\n",
+    "        \n",
+    "        ques = \"random starter\"  ## sentinel so the Q&A loop runs at least once\n",
+    "        while(ques != \"exit\"):\n",
+    "            ## Read Question\n",
+    "            ques = input(\"Bot:> How can I help ?\\nPerson:>\")\n",
+    "            \n",
+    "            ## Check if it is an instructional question\n",
+    "            if is_instruction_option(ques):\n",
+    "                if ques == \"summary\":\n",
+    "                    print(\"Bot:> ================= Discharge Summary for Patient Id \",pid,\"\\n\")\n",
+    "                    print(df_text[pid])\n",
+    "                elif ques == \"reveal\":\n",
+    "                    print(topic_mapping, topic_mapping.keys())\n",
+    "                continue\n",
+    "            \n",
+    "            ## Extract Question topic\n",
+    "            topic_q = extract_Q_topic(ques)\n",
+    "            if topic_q is None:\n",
+    "                print(\"Bot:> I am a specialized NLP bot, please ask a more specific question for me!\")\n",
+    "                continue\n",
+    "            ans = get_extracted_answer(topic_q, df_text[pid])\n",
+    "            if ans is not None:\n",
+    "                print(\"Bot:> \",ans)\n",
+    "            else:\n",
+    "                ans = get_direct_answer(topic_q, topic_mapping)\n",
+    "                if ans is not None:\n",
+    "                    print(\"Bot:> \",ans)\n",
+    "                else:\n",
+    "                    ## get_answer expects a single topic word, so pass the top-ranked one\n",
+    "                    ans = get_answer(topic_q[0][0], topic_mapping, embedding_short, all_topics, data_set, vocab_map, pca, 5)\n",
+    "                    if ans is not None:\n",
+    "                        print(\"Bot:> \",ans)\n",
+    "                    else:\n",
+    "                        print(\"Bot:> Sorry but, I have no information on this topic!\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}