Diff of /MED277_bot.py [000000] .. [7b0fc8]

# coding: utf-8

# In[59]:


import pandas as pd
from sklearn.externals import joblib
import re
from nltk.stem.snowball import SnowballStemmer
from collections import defaultdict
import operator
import numpy as np
import sklearn.feature_extraction.text as text
from sklearn import decomposition
from nltk.stem import PorterStemmer, WordNetLemmatizer
from sklearn.decomposition import PCA
from numpy.linalg import norm

# In[60]:


def load_data():
    ## Initializing data paths
    base_path = r'D:\ORGANIZATION\UCSD_Life\Work\4. Quarter-3\Subjects\MED 277\Project\DATA\\'
    data_file = base_path+"NOTEEVENTS.csv.gz"

    ## Loading data frames from the CSV file
    #df = pd.read_csv(data_file, compression='gzip')
    #df = df[:10000]
    #joblib.dump(df,base_path+'data10.pkl')

    ## Loading data frames from the PKL dump
    df1 = joblib.load(base_path+'data10.pkl')
    df = df1[:50]

    ## Keeping only discharge summaries and their free-text column
    df = df.loc[df['CATEGORY'] == 'Discharge summary'] #Extracting only discharge summaries
    df_text = df['TEXT']
    return df_text

# ## EXTRACT ALL THE TOPICS

# In[61]:


'''Processes the entire document string: collapses newlines and strips every
character except letters, spaces and periods (periods are kept so the text
can later be split into sentences).'''
def process_text(txt):
    txt1 = re.sub('[\n]'," ",txt)
    txt1 = re.sub('[^A-Za-z \.]+', '', txt1)

    return txt1

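## Illustrative example (not in the original script): on a small synthetic note,
## process_text behaves like this:
##
##   >>> process_text("Allergies:\nPenicillin, Aspirin.")
##   'Allergies Penicillin Aspirin.'
##
## Newlines become spaces and everything except letters, spaces and periods is dropped.
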
# In[62]:


'''Processes the document string without preserving sentence boundaries:
keeps only letters and spaces, then stems every word.'''
def process(txt):
    txt1 = re.sub('[\n]'," ",txt)
    txt1 = re.sub('[^A-Za-z ]+', '', txt1)

    _wrds = txt1.split()
    stemmer = SnowballStemmer("english") ## May use Porter stemmer instead
    wrds = [stemmer.stem(wrd) for wrd in _wrds]
    return wrds

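## Illustrative example (not in the original script): with the Snowball stemmer,
## process("Allergies:\nPenicillin") should return something like
## ['allergi', 'penicillin'] -- non-letters are dropped and each word is stemmed.
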
# In[63]:


'''Splits the processed text on periods and returns only the sentences that
contain at least five words.'''
def get_processed_sentences(snt_txt):
    snt_list = []
    for line in snt_txt.split('.'):
        line = line.strip()
        if len(line.split()) >= 5:
            snt_list.append(line)
    return snt_list

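## Illustrative example (not in the original script):
## get_processed_sentences("He was stable. Patient was discharged home in stable condition.")
## keeps only the second sentence, since the first has fewer than five words.
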
# In[64]:


'''Extracts topic words from a sentence: LDA is fit on a bag-of-words built from
the sentence's tokens, and the top words of each topic are returned.'''
def extract_topic(str_arg, num_topics = 1, num_top_words = 3):
    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')
    try:
        dtm = vectorizer.fit_transform(str_arg.split())
        vocab = np.array(vectorizer.get_feature_names())

        #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction
        clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')
        clf.fit_transform(dtm)

        topic_words = []
        for topic in clf.components_:
            word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list
            topic_words.append([vocab[i] for i in word_idx])
        return topic_words
    except:
        return None

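## Illustrative example (not in the original script): the output is a list with one
## word list per topic, e.g. extract_topic("patient was discharged home in stable condition")
## might return [['stable', 'patient', 'discharged']]; the exact words depend on the LDA fit.
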
# In[65]:


'''Extracts the topics of each sentence in a document and returns them as a flat list.'''
def extract_topics_all(doc_string):
    #One entry per sentence in the list
    doc_str = process_text(doc_string)
    doc_str = get_processed_sentences(doc_str)

    res = []
    for i in range(0, len(doc_str)):
        snd_str = doc_str[i].lower()
        #print("Sending ----------------------------",snd_str,"==========",len(snd_str))
        tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)
        if tmp_topic is None:   ## extract_topic returns None when LDA fails on a sentence
            continue
        for top in tmp_topic:
            for wrd in top:
                res.append(wrd)
    return res

# In[66]:


'''This function takes a dataframe and returns all the topics in the entire corpus'''
def extract_corpus_topics(arg_df):
    all_topics = set()
    cnt = 1
    for txt in arg_df:
        all_topics = all_topics.union(extract_topics_all(txt))
        print("Processed ",cnt," records")
        cnt += 1
    all_topics = list(all_topics)
    return all_topics

# ## GET A VECTORIZED REPRESENTATION OF ALL THE TOPICS

# In[67]:


'''data_set = list of words per document.
    vocabulary = list of all the distinct words present
    _vocab = dict of word counts for words in vocabulary'''
def get_vocab_wrd_map(df_text):
    data_set = []
    vocabulary = []
    _vocab = defaultdict(int)
    for i in range(0,df_text.size):
        txt = process(df_text[i])
        data_set.append(txt)

        for wrd in txt:
            _vocab[wrd] += 1

        vocabulary = vocabulary + txt
        vocabulary = list(set(vocabulary))

        if(i%100 == 0):
            print("%5d records processed"%(i))
    return data_set, vocabulary, _vocab

# In[68]:


'''Returns the num_arg most common (word, count) pairs, sorted by count in descending order.'''
def get_common_vocab(num_arg, vocab):
    vocab = sorted(vocab.items(), key=operator.itemgetter(1), reverse=True)
    vocab = vocab[:num_arg]
    return vocab

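## Illustrative example (not in the original script):
## get_common_vocab(2, {'pain': 5, 'stabl': 9, 'admiss': 2}) returns
## [('stabl', 9), ('pain', 5)] -- note the entries are (word, count) tuples.
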
# In[69]:


'''Converts the vocabulary and the most common words into word-to-index maps for faster access.'''
def get_vocab_map(vocabulary, vocab):
    vocab_map = {}
    for i in range(0,len(vocab)):
        vocab_map[vocab[i][0]] = i

    vocabulary_map = {}
    for i in range(0,len(vocabulary)):
        vocabulary_map[vocabulary[i]] = i

    return vocabulary_map, vocab_map

# In[70]:


'''Builds a context embedding for `word`: counts of the vocabulary words that appear
within wdw_size positions of each occurrence of `word`, normalized to sum to 1.'''
def get_embedding(word, data_set, vocab_map, wdw_size):
    embedding = [0]*len(vocab_map)
    for docs in data_set:
        for i in range(wdw_size, len(docs)-wdw_size):
            if docs[i] == word:
                ## Count the wdw_size words on each side of the occurrence
                for j in range(i-wdw_size, i):
                    if docs[j] in vocab_map:
                        embedding[vocab_map[docs[j]]] += 1
                for j in range(i+1, i+wdw_size+1):
                    if docs[j] in vocab_map:
                        embedding[vocab_map[docs[j]]] += 1
    total_words = sum(embedding)
    if total_words != 0:
        embedding[:] = [e/total_words for e in embedding]
    return embedding

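## Illustrative example (not in the original script): with vocab_map = {'pain': 0, 'chest': 1},
## wdw_size = 1 and data_set = [['the', 'chest', 'pain', 'was', 'mild']], the call
## get_embedding('pain', data_set, vocab_map, 1) counts 'chest' once ('was' is not in
## vocab_map), so the normalized result is [0.0, 1.0].
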
# In[71]:


def get_embedding_all(all_topics, data_set, vocab_map, wdw_size):
    embeddings = []
    for i in range(0, len(all_topics)):
        embeddings.append(get_embedding(all_topics[i], data_set, vocab_map, wdw_size))
    return embeddings

# ## Get similarity function

# In[72]:


def cos_matrix_multiplication(matrix, vector):
    """
    Calculates the cosine similarity between each row of `matrix` and `vector`
    using a single matrix-vector multiplication.
    """
    dotted = matrix.dot(vector)
    matrix_norms = np.linalg.norm(matrix, axis=1)
    vector_norm = np.linalg.norm(vector)
    matrix_vector_norms = np.multiply(matrix_norms, vector_norm)
    neighbors = np.divide(dotted, matrix_vector_norms)
    return neighbors

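## Illustrative example (not in the original script):
## cos_matrix_multiplication(np.array([[1, 0], [0, 1], [1, 1]]), np.array([1, 0]))
## returns approximately [1.0, 0.0, 0.707] -- one cosine similarity per matrix row.
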
# In[73]:


def get_most_similar_topics(embd, embeddings, all_topics, num_wrd=10):
    sim_top = []
    cos_sim = cos_matrix_multiplication(np.array(embeddings), embd)
    #closest_match = cos_sim.argsort()[-num_wrd:][::-1] ## This sorts all matches in order

    ## Keep only matches with cosine similarity above 0.9, ordered from most to least similar
    idx = list(np.where(cos_sim > 0.9)[0])
    val = list(cos_sim[np.where(cos_sim > 0.9)])
    if len(idx) == 0:
        return sim_top
    closest_match, match_val = (list(t) for t in zip(*sorted(zip(idx, val), key=operator.itemgetter(1), reverse=True)))
    closest_match = np.array(closest_match)

    for i in range(0, closest_match.shape[0]):
        sim_top.append(all_topics[closest_match[i]])
    return sim_top

# ## Topic Modelling

# In[74]:


def get_regex_match(regex, str_arg):
    srch = re.search(regex,str_arg)
    if srch is not None:
        return srch.group(0).strip()
    else:
        return "Not found"

# In[75]:


'''Extracts simple header fields (DOB, admission/discharge date, sex, service,
allergies, attending) from the note text using regular expressions.'''
def extract(key,str_arg):
    if key == 'dob':
        return get_regex_match('Date of Birth:(.*)] ', str_arg)
    elif key == 'a_date':
        return get_regex_match('Admission Date:(.*)] ', str_arg)
    elif key == 'd_date':
        return get_regex_match('Discharge Date:(.*)]\n', str_arg)
    elif key == 'sex':
        return get_regex_match('Sex:(.*)\n', str_arg)
    elif key == 'service':
        return get_regex_match('Service:(.*)\n', str_arg)
    elif key == 'allergy':
        return get_regex_match('Allergies:(.*)\n(.*)\n', str_arg)
    elif key == 'attdng':
        return get_regex_match('Attending:(.*)]\n', str_arg)
    else:
        return "I Don't know"

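## Illustrative example (not in the original script): on a note containing the header
## line "Sex:   F\n", extract('sex', note_text) returns the stripped match "Sex:   F";
## the regexes assume MIMIC-style headers, and get_regex_match returns "Not found" otherwise.
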
# In[76]:


'''This method extracts topic words from a sentence. It redefines the earlier
extract_topic without the internal try/except; extract_Q_topic below handles failures.'''
def extract_topic(str_arg, num_topics = 1, num_top_words = 3):
    vectorizer = text.CountVectorizer(input='content', analyzer='word', lowercase=True, stop_words='english')
    dtm = vectorizer.fit_transform(str_arg.split())
    vocab = np.array(vectorizer.get_feature_names())

    #clf = decomposition.NMF(n_components=num_topics, random_state=1) ## topic extraction
    clf = decomposition.LatentDirichletAllocation(n_components=num_topics, learning_method='online')
    clf.fit_transform(dtm)

    topic_words = []
    for topic in clf.components_:
        word_idx = np.argsort(topic)[::-1][0:num_top_words] ##[::-1] reverses the list
        topic_words.append([vocab[i] for i in word_idx])
    return topic_words

# In[77]:


'''This method extracts the topics in a question'''
def extract_Q_topic(str_arg):
    try:
        return extract_topic(str_arg)
    except:
        return None
    ## TODO fix later for more comprehensive results

# In[78]:


def get_extract_map(key_wrd):
    ## A Stemmed mapping for simple extractions
    extract_map = {'birth':'dob', 'dob':'dob',
              'admiss':'a_date', 'discharg':'d_date',
              'sex':'sex', 'gender':'sex', 'servic':'service',
              'allergi':'allergy', 'attend':'attdng'}
    if key_wrd in extract_map.keys():
        return extract_map[key_wrd]
    else:
        return None

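## Illustrative example (not in the original script): with the Porter stemmer,
## PorterStemmer().stem("allergies") gives "allergi", so
## get_extract_map("allergi") returns 'allergy', which extract() then handles.
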
# In[79]:


'''Method that generates the answer for text extraction questions'''
def get_extracted_answer(topic_str, text):
    port = PorterStemmer()
    for i in range(0, len(topic_str)):
        rel_wrd = topic_str[i]
        for wrd in rel_wrd:
            key = get_extract_map(port.stem(wrd))
            if key is not None:
                return extract(key, text)
    return None

# In[80]:


'''Extracts the topics of each sentence and returns a dict mapping each topic word
to the sentences it was extracted from.'''
def get_topic_mapping(doc_string):
    #One entry per sentence in the list
    doc_str = process_text(doc_string)
    doc_str = get_processed_sentences(doc_str)

    res = defaultdict(list)
    for i in range(0, len(doc_str)):
        snd_str = doc_str[i].lower()
        #print("Sending ----------------------------",snd_str,"==========",len(snd_str))
        tmp_topic = extract_topic(snd_str, num_topics = 2, num_top_words = 1)
        for top in tmp_topic:
            for wrd in top:
                res[wrd].append(doc_str[i])
    return res

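## Illustrative example (not in the original script): for a summary containing the
## sentence "Patient was discharged home in stable condition", the returned map might
## contain an entry like {'stable': ['Patient was discharged home in stable condition']},
## depending on which topic words the LDA step picks.
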
# In[81]:


def get_direct_answer(topic_str, topic_map):
    ## Maybe apply lemmatizer here
    for i in range(0, len(topic_str)):
        rel_wrd = topic_str[i]
        for wrd in rel_wrd:
            if wrd in topic_map.keys():
                return topic_map[wrd]
    return None

# In[82]:


def get_answer(topic_str, topic_map, embedding_short, all_topics, data_set, vocab_map, pca, wdw_size=5):
    ## For each question-topic word, look for similar topics in the reduced embedding space
    for rel_wrd in topic_str:
        for wrd in rel_wrd:
            tpc_embedding = get_embedding(wrd, data_set, vocab_map, wdw_size)
            tpc_embedding = pca.transform([tpc_embedding])
            sim_topics = get_most_similar_topics(tpc_embedding[0], embedding_short, all_topics, num_wrd = len(all_topics))
            for sim_topic in sim_topics:
                if sim_topic in topic_map.keys():
                    return topic_map[sim_topic]
    return None

# In[83]:


'''Checks whether the user input is one of the chatbot's built-in commands.'''
def is_instruction_option(str_arg):
    if str_arg == "exit" or str_arg == "summary" or str_arg == "reveal":
        return True
    else:
        return False

def print_bot():
    print(r"          _ _ _")
    print(r"         | o o |")
    print(r"        \|  =  |/")
    print(r"         -------")
    print(r"         |||||||")
    print(r"         //   \\")

def print_caption():
    print(r"    ||\\   ||  ||       ||= =||")
    print(r"    || \\  ||  ||       ||= =||")
    print(r"    ||  \\ ||  ||       ||")
    print(r"    ||   \\||  ||_ _ _  ||")

# In[ ]:


if __name__ == "__main__":
    print("Loading data ...","\n")
    df_text = load_data()

    print("Getting Vocabulary ...")
    data_set, vocabulary, _vocab = get_vocab_wrd_map(df_text)

    print("Creating context ...")
    vocab = get_common_vocab(1000, _vocab)
    vocabulary_map, vocab_map = get_vocab_map(vocabulary, vocab)

    print("Learning topics ...")
    all_topics = extract_corpus_topics(df_text)

    print("Getting Embeddings ...")
    embeddings = get_embedding_all(all_topics, data_set, vocab_map, 5)

    pca = PCA(n_components=10)
    embedding_short = pca.fit_transform(embeddings)

    print_caption()
    print_bot()
    print("Bot:> I am online!")
    print("Bot:> Type \"exit\" to end a patient's session")
    print("Bot:> Type \"summary\" to view a patient's discharge summary")
    while(True):
        while(True):
            try:
                pid = int(input("Bot:> What is your Patient Id [0 to "+str(df_text.shape[0]-1)+"]? "))
            except:
                continue
            if pid < 0 or pid > df_text.shape[0]-1:
                print("Bot:> Patient Id out of range!")
                continue
            else:
                print("Bot:> Reading Discharge Summary for Patient Id: ",pid)
                break

        personal_topics = extract_topics_all(df_text[pid])
        topic_mapping = get_topic_mapping(df_text[pid])

        ques = "random starter"
        while(ques != "exit"):
            ## Read the question
            ques = input("Bot:> How can I help ?\nPerson:>")

            ## Check if it is an instructional command
            if is_instruction_option(ques):
                if ques == "summary":
                    print("Bot:> ================= Discharge Summary for Patient Id ",pid,"\n")
                    print(df_text[pid])
                elif ques == "reveal":
                    print(topic_mapping, topic_mapping.keys())
                continue

            ## Extract the question topic
            topic_q = extract_Q_topic(ques)
            if topic_q is None:
                print("Bot:> I am a specialized NLP bot, please ask a more specific question!")
                continue
            ans = get_extracted_answer(topic_q, df_text[pid])
            if ans is not None:
                print("Bot:> ",ans)
            else:
                ans = get_direct_answer(topic_q, topic_mapping)
                if ans is not None:
                    print("Bot:> ",ans)
                else:
                    ans = get_answer(topic_q, topic_mapping, embedding_short, all_topics, data_set, vocab_map, pca, 5)
                    if ans is not None:
                        print("Bot:> ",ans)
                    else:
                        print("Bot:> Sorry, but I have no information on this topic!")