# MedicalRelationExtractor/unibiased.py
1
#AUTHOR: RAHUL VERMA and SPIRO RAZIS
2
import sys
3
import re
4
import pprint
5
import numpy
6
from sklearn import svm
7
from sklearn import linear_model
8
import time
9
from random import shuffle
10
11
# Record the script start time so total runtime can be reported if needed.
start_time = time.time()

# Print full numpy arrays without truncation.  Modern numpy rejects a
# non-integer threshold (numpy.nan raised "threshold must be numeric"),
# so use the documented idiom: threshold=sys.maxsize.
numpy.set_printoptions(threshold=sys.maxsize)
15
def parseTextViaPMCID(textFile, pmcidFeatureList, uniqueWordsDictionary, lim):
    """Parse up to `lim` labeled records from a beneficial/harmful corpus file.

    Each record contributes [disease, causeOrTreatment, relation, fileType]
    (all UTF-8 bytes) to pmcidFeatureList, and both entity phrases are
    registered as keys of uniqueWordsDictionary.

    Returns (pmcidFeatureList, number of pmcid headers seen,
    uniqueWordsDictionary).  Exits with status 2 on an unrecognized file
    name or an unparseable line.
    """
    # The class label applied to every record comes from the file name.
    if textFile.startswith("beneficial"):
        fileType = "beneficial".encode('utf-8')
    elif textFile.startswith("harmful"):
        fileType = "harmful".encode('utf-8')
    else:
        sys.exit(2)  # invalid file name

    parsed = 0                  # completed records so far, capped at lim
    entryCount = 0              # "pmcid   : " headers encountered
    disease = ""
    causeOrTreatment = ""
    relation = ""
    inEntry = False             # True between a pmcid header and its blank line

    with open(textFile, "r") as corpus:
        for line in corpus:
            if parsed >= lim:
                break
            if line.startswith("pmcid   : "):          # record header
                entryCount += 1
                inEntry = True
            elif line.startswith("sentence: ") or line.startswith("offsets : "):
                # Sentence text and entity offsets are not used by this pass.
                continue
            elif line.startswith("entities: "):
                # NOTE(review): slicing starts at 11 here (vs 10 for
                # "relation: ") — presumably the entities line carries an
                # extra leading space; confirm against the corpus format.
                comma = line.index(",")
                disease = line[11:comma].lower().encode('utf-8')
                causeOrTreatment = line[(comma + 2):-2].lower().encode('utf-8')
                # Register both entity phrases as unique words/phrases.
                uniqueWordsDictionary.setdefault(disease, {})
                uniqueWordsDictionary.setdefault(causeOrTreatment, {})
            elif line.startswith("relation: "):
                relation = line[10:-1].lower().encode('utf-8')
            elif line.startswith("\n") and inEntry:
                # Blank line terminates the record: flush accumulated fields.
                pmcidFeatureList.append([disease, causeOrTreatment, relation, fileType])
                disease = ""
                causeOrTreatment = ""
                relation = ""
                inEntry = False
                parsed += 1
            else:
                print("invalid line: %s" % (line))
                sys.exit(2)

    return (pmcidFeatureList, entryCount, uniqueWordsDictionary)
71
72
73
def printFeatureWithCellValue(numpyRow, featureRow):
    """Print every feature name with its cell value, then the class label."""
    position = 0
    for name in featureRow:
        print("%s: %d" % (name, numpyRow[position]))
        position += 1
    # The final cell of the row holds the harmful/beneficial label.
    print("harmfulOrBeneficial: %d" % (numpyRow[-1]))
    return
78
79
def printFeaturesWithValuesEqualOne(numpyRow, featureRow):
    """Print only the features whose cell value is exactly 1, then the label."""
    for position, name in enumerate(featureRow):
        value = numpyRow[position]
        if value == 1:
            print("%s: %d" % (name, value))
    # The final cell of the row holds the harmful/beneficial label.
    print("harmfulOrBeneficial: %d" % (numpyRow[-1]))
    return
85
86
87
def parseEntitiesIntoUnigrams(beneficialFile, harmfulFile, beneficialLimit, harmfulLimit):
    """Build the feature vocabulary from the beneficial and harmful corpora.

    Two passes over each file:
      1. For the first beneficialLimit / harmfulLimit "entities:" lines,
         collect every full entity phrase (as bytes) and every unigram that
         occurs inside an entity.
      2. Split every sentence into unigrams; for words NOT appearing inside
         entities, track how many distinct pmcids each word occurs under,
         separately per corpus.  A word seen in more than one document and
         more than twice as often in one corpus than the other becomes a
         "beneficial" or "harmful" discriminative unigram.

    Returns a 9-tuple:
      (entitiesTrainingDictionary, beneficialUnigrams, harmfulUnigrams,
       beneficialEntry, harmfulEntry, beneficialSplitSentences,
       harmfulSplitSentences, beneficialFullEntitiesList,
       harmfulFullEntitiesList)
    where the Entry counts are TOTAL sentence counts per corpus (the
    counters are reset before the second pass and incremented for every
    sentence, not only the limited ones).
    """

    beneficialEntry = 0
    harmfulEntry = 0
    # Keys: full entity phrases (bytes) from both corpora; values unused.
    entitiesTrainingDictionary= {}

    entityUnigramList           = []

    # Per-corpus [disease, causeOrTreatment] byte-string pairs, one per record.
    beneficialFullEntitiesList  = []
    harmfulFullEntitiesList     = []

    sentenceUnigramList         = []

    # Per-corpus tokenized sentences, in file order.
    beneficialSplitSentences    = []
    harmfulSplitSentences       = []

    # Unigrams occurring inside entity phrases (excluded from sentence features).
    entityUnigrams              = {}
    harmfulUnigrams             = {}
    beneficialUnigrams          = {}
    # word -> {"beneficial"/"harmful" -> {"pmcid": {...}, "count": int}}
    sentenceFeatureUnigrams             = {}

    # 1x1 byte-string array used purely as an ASCII filter: assigning a
    # non-ASCII word raises UnicodeEncodeError, which is caught below to
    # skip words that could not be stored in an S128 feature array.
    testArrayForWritingEntries = numpy.empty(shape = (1, 1), dtype = "S128")

    # PASS 1a: unigrams of the training beneficial entities.
    with open(beneficialFile, "r") as openedBeneficialFile:
        for line in openedBeneficialFile:
            if beneficialEntry < beneficialLimit:
                if line.startswith("entities: "):
                    # Individual entities (slices presumably skip the
                    # "entities:  " prefix and trailing ".\n" — TODO confirm
                    # against the corpus format).
                    disease = line[11:line.index(",")].lower().encode('utf-8')
                    causeOrTreatment = line[(line.index(",")+2):-2].lower().encode('utf-8')
                    if disease not in entitiesTrainingDictionary:
                        entitiesTrainingDictionary[disease] = {}
                    if causeOrTreatment not in entitiesTrainingDictionary:
                        entitiesTrainingDictionary[causeOrTreatment] = {}
                    # Unigrams composing the entities.
                    entityUnigramList = re.split("-|, |\. |\/| ", line[11:-2].lower())
                    for entry in entityUnigramList:
                        if (entry != "") and (entry not in entityUnigrams):
                            entityUnigrams[entry] = {}
                    beneficialEntry += 1
            else: break
    # PASS 1b: unigrams of the training harmful entities.
    with open(harmfulFile, "r") as openedHarmfulFile:
        for line in openedHarmfulFile:
            if harmfulEntry < harmfulLimit:
                if line.startswith("entities: "):
                    # Individual entities.
                    disease = line[11:line.index(",")].lower().encode('utf-8')
                    causeOrTreatment = line[(line.index(",")+2):-2].lower().encode('utf-8')
                    if disease not in entitiesTrainingDictionary:
                        entitiesTrainingDictionary[disease] = {}
                    if causeOrTreatment not in entitiesTrainingDictionary:
                        entitiesTrainingDictionary[causeOrTreatment] = {}
                    entityUnigramList = re.split("-|, |\. |\/| ", line[11:-2].lower())
                    for entry in entityUnigramList:
                        if (entry != "")  and (entry not in entityUnigrams):
                            entityUnigrams[entry] = {}
                    harmfulEntry += 1
            else: break

    # PASS 2a: tokenize beneficial sentences; build per-pmcid document
    # frequency for words outside the entity vocabulary.
    beneficialEntry = 0
    mostRecentPMCID = ""
    with open(beneficialFile, "r") as openedBeneficialFile:
        for line in openedBeneficialFile:
            if line.startswith("pmcid   : "): #it's the pmcid line
                mostRecentPMCID = line[11:-1]
            elif line.startswith("sentence: "):
                sentenceUnigramList = re.split("\—|\-|\, |\.|\/|\(|\)|\'|\"|\[|\]|\ |\“|\”|\,|\d|\<|\>|\:|\$|\%|\*|\′", line[10:-2].lower())
                beneficialSplitSentences.append(sentenceUnigramList)

                # Only sentences within the training limit contribute to
                # the discriminative-unigram statistics.
                if beneficialEntry < beneficialLimit: 
                    for word in sentenceUnigramList:
                        if (word != "") and (word not in entityUnigrams):
                            if word not in sentenceFeatureUnigrams:
                                try: 
                                    # ASCII filter: raises UnicodeEncodeError
                                    # for words unrepresentable as S128 bytes.
                                    testArrayForWritingEntries[0,0] = word
                                    sentenceFeatureUnigrams[word] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"]["pmcid"] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"]["pmcid"][mostRecentPMCID] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"]["count"] = 0

                                    sentenceFeatureUnigrams[word]["harmful"] = {}
                                    sentenceFeatureUnigrams[word]["harmful"]["pmcid"] = {}
                                    sentenceFeatureUnigrams[word]["harmful"]["count"] = 0
                                except UnicodeEncodeError: pass 
                            else: # already a feature word: record this document
                                if mostRecentPMCID not in sentenceFeatureUnigrams[word]["beneficial"]["pmcid"]: #and the same pmcid isn't already there
                                    sentenceFeatureUnigrams[word]["beneficial"]["pmcid"][mostRecentPMCID] = {}

                beneficialEntry += 1
            elif line.startswith("entities: "):
                # Individual entities, kept in file order to align with
                # beneficialSplitSentences.
                disease = line[11:line.index(",")].lower().encode('utf-8')
                causeOrTreatment = line[(line.index(",")+2):-2].lower().encode('utf-8')

                beneficialFullEntitiesList.append([disease, causeOrTreatment]) 

            else: pass

    # PASS 2b: same for the harmful corpus.
    harmfulEntry = 0
    mostRecentPMCID = ""
    with open(harmfulFile, "r") as openedHarmfulFile:
        for line in openedHarmfulFile:
            if line.startswith("pmcid   : "): #it's the pmcid line
                mostRecentPMCID = line[11:-1]                
            elif line.startswith("sentence: "):
                sentenceUnigramList = re.split("\—|\-|\, |\.|\/|\(|\)|\'|\"|\[|\]|\ |\“|\”|\,|\d|\<|\>|\:|\$|\%|\*|\′", line[10:-2].lower())
                harmfulSplitSentences.append(sentenceUnigramList)
                if harmfulEntry < harmfulLimit: 
                    for word in sentenceUnigramList:
                        if (word != "") and (word not in entityUnigrams):
                            if word not in sentenceFeatureUnigrams:
                                try:
                                    # ASCII filter (see above).
                                    testArrayForWritingEntries[0,0] = word
                                    sentenceFeatureUnigrams[word] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"]["pmcid"] = {}
                                    sentenceFeatureUnigrams[word]["beneficial"]["count"] = 0

                                    sentenceFeatureUnigrams[word]["harmful"] = {}
                                    sentenceFeatureUnigrams[word]["harmful"]["pmcid"] = {}
                                    sentenceFeatureUnigrams[word]["harmful"]["pmcid"][mostRecentPMCID] = {}
                                    sentenceFeatureUnigrams[word]["harmful"]["count"] = 0
                                except UnicodeEncodeError: pass 
                            else:
                                if mostRecentPMCID not in sentenceFeatureUnigrams[word]["harmful"]["pmcid"]: #and the same pmcid isn't already there
                                    sentenceFeatureUnigrams[word]["harmful"]["pmcid"][mostRecentPMCID] = {}
                harmfulEntry += 1
            elif line.startswith("entities: "):
                disease = line[11:line.index(",")].lower().encode('utf-8')
                causeOrTreatment = line[(line.index(",")+2):-2].lower().encode('utf-8')

                harmfulFullEntitiesList.append([disease, causeOrTreatment])

            else: pass

    # Turn the per-pmcid sets into document-frequency counts, then keep the
    # words that lean strongly (more than 2x) toward one corpus.
    for word in sentenceFeatureUnigrams:
        for benefitHarmfulOrEntity in sentenceFeatureUnigrams[word]:
            # Count distinct pmcids per corpus.
            for pmcid in sentenceFeatureUnigrams[word][benefitHarmfulOrEntity]["pmcid"]:
                sentenceFeatureUnigrams[word][benefitHarmfulOrEntity]["count"] += 1

        if (sentenceFeatureUnigrams[word]["beneficial"]["count"] > 1) or (sentenceFeatureUnigrams[word]["harmful"]["count"] > 1):
            if sentenceFeatureUnigrams[word]["beneficial"]["count"] > (2*sentenceFeatureUnigrams[word]["harmful"]["count"]):
                beneficialUnigrams[word] = {}
            elif sentenceFeatureUnigrams[word]["harmful"]["count"] > (2*sentenceFeatureUnigrams[word]["beneficial"]["count"]):
                harmfulUnigrams[word] = {}
            else: pass #the words can't be categorized one way or the other

    return (entitiesTrainingDictionary, 
            beneficialUnigrams, harmfulUnigrams, 
            beneficialEntry, harmfulEntry, 
            beneficialSplitSentences, harmfulSplitSentences,
            beneficialFullEntitiesList, harmfulFullEntitiesList)
249
250
251
def main(argv):
    """Train and evaluate logistic-regression and SVM relation classifiers.

    Usage: python3 unibiased.py beneficial.txt harmful.txt

    Builds binary feature matrices (entity phrases + discriminative
    sentence unigrams + one trailing label column, beneficial=1/harmful=0),
    fits both classifiers, prints their accuracies, and exits.
    """
    if len(argv) != 3:
        print("invalid number of arguments")
        sys.exit(2)

    # Two separate lists because we don't know how many entries each file
    # has, so dividing one combined list would be difficult.
    (entitiesTrainingDictionary, beneficialUnigrams, harmfulUnigrams,
        beneficialCount, harmfulCount, pmcidBeneficialSentences, pmcidHarmfulSentences,
        beneficialFullEntitiesList, harmfulFullEntitiesList) = parseEntitiesIntoUnigrams(argv[1], argv[2], 10356, 9797)

    # Fraction of each corpus covered by the hard-coded training limits.
    benprec = 10356 / beneficialCount
    harmprec = 9797 / harmfulCount

    # +1 column for the harmful/beneficial label.
    numFeatures = len(entitiesTrainingDictionary) + len(beneficialUnigrams) + len(harmfulUnigrams) + 1
    uniqueFeaturesArray = numpy.empty(shape=(1, numFeatures), dtype="S128")

    # Lay the entity phrases, then the beneficial and harmful unigrams,
    # into the feature-name row.
    for index, feature in enumerate(entitiesTrainingDictionary):
        uniqueFeaturesArray[0, index] = feature

    finalColumn = len(entitiesTrainingDictionary)
    for index, feature in enumerate(beneficialUnigrams):
        uniqueFeaturesArray[0, index + finalColumn] = feature
    finalColumn += len(beneficialUnigrams)

    for index, feature in enumerate(harmfulUnigrams):
        uniqueFeaturesArray[0, index + finalColumn] = feature

    # Sort all feature names (excluding the label slot) so binary search
    # via numpy.searchsorted works below.
    uniqueFeaturesArray[0][:-1].sort()

    beneficial80Percent = int(beneficialCount * benprec) - 1
    beneficial20Percent = int(beneficialCount - beneficial80Percent)
    harmful80Percent    = int(harmfulCount * harmprec) - 1
    harmful20Percent    = int(harmfulCount - harmful80Percent)

    # BUGFIX: these were numpy.empty, leaving garbage in every cell that
    # was never explicitly written — notably the label column of every
    # harmful row and all absent-feature cells.  Zero-initialize so
    # absent features and the harmful label are 0.
    trainArray = numpy.zeros(shape=((beneficial80Percent + harmful80Percent), numFeatures), dtype=numpy.int8)
    testArray  = numpy.zeros(shape=((beneficial20Percent + harmful20Percent), numFeatures), dtype=numpy.int8)

    featureNames = uniqueFeaturesArray[0][:-1]

    def _markFeature(matrix, row, token):
        # Set matrix[row, col] = 1 iff `token` (bytes) is a known feature.
        featureColumn = numpy.searchsorted(featureNames, token)
        if uniqueFeaturesArray[0][featureColumn] == token:
            matrix[row, featureColumn] = 1

    # ---- training rows: beneficial (label 1) then harmful (label 0) ----
    for entry in range(0, beneficial80Percent):
        for word in pmcidBeneficialSentences[entry]:
            _markFeature(trainArray, entry, word.encode("utf-8"))
        for entity in beneficialFullEntitiesList[entry]:
            _markFeature(trainArray, entry, entity)
        trainArray[entry, -1] = 1

    for entry in range(0, harmful80Percent):
        trainingEntry = entry + beneficial80Percent
        for word in pmcidHarmfulSentences[entry]:
            _markFeature(trainArray, trainingEntry, word.encode("utf-8"))
        for entity in harmfulFullEntitiesList[entry]:
            _markFeature(trainArray, trainingEntry, entity)
        # Label column stays 0 (harmful) from zero-initialization.

    # ---- test rows ----
    for entry in range(0, beneficial20Percent):
        dataEntry = entry + beneficial80Percent  # next beneficial entries
        for word in pmcidBeneficialSentences[dataEntry]:
            _markFeature(testArray, entry, word.encode("utf-8"))
        for entity in beneficialFullEntitiesList[dataEntry]:
            _markFeature(testArray, entry, entity)
        testArray[entry, -1] = 1

    for entry in range(0, harmful20Percent):
        dataEntry = entry + harmful80Percent       # next harmful entries
        testEntry = entry + beneficial20Percent    # appended after beneficial
        for word in pmcidHarmfulSentences[dataEntry]:
            # BUGFIX: originally searched on stale `feature` (bytes left
            # over from the enumerate loops above) instead of `word`,
            # which raised AttributeError (bytes has no .encode).
            _markFeature(testArray, testEntry, word.encode("utf-8"))
        for entity in harmfulFullEntitiesList[dataEntry]:
            _markFeature(testArray, testEntry, entity)

    # ----------------------- classification section -----------------------
    # Feature vectors are every column but the last; classes are the label.
    supportVectorsL = []
    classesListL = []
    for row in trainArray:
        supportVectorsL.append(row[:-1])
        classesListL.append(row[len(row) - 1])
    supportVectors = numpy.asarray(supportVectorsL)
    classesList = numpy.asarray(classesListL)

    # Logistic regression.
    classifier = linear_model.LogisticRegression()
    classifier.fit(supportVectors, classesList)
    print("--------------------LOGISTIC------------------------")
    logistic(classifier, testArray, "TEST")

    print("--------------------SVM------------------------")
    classifier = svm.SVC()
    classifier.fit(supportVectors, classesList)
    # NOTE(review): switching the kernel AFTER fit() looks unintended —
    # the model was trained with the default kernel; confirm.
    classifier.kernel = "linear"
    SVC(classifier, testArray, "TEST")

    sys.exit(0)
383
384
def SVC(classifier, testArray, t):
    """Predict each test row with the fitted SVM and report accuracies."""
    predictions = []
    for row in testArray:
        features = row[:-1]
        if 1 not in features:
            # A row with no active feature cannot be classified; use a
            # sentinel label that never matches the truth column.
            predictions.append(-1)
        else:
            predictions.append(int(classifier.predict([features])[0]))
    totalAccuray(testArray, predictions, t)
    featAccuracy(testArray, predictions, t, 1)
    featAccuracy(testArray, predictions, t, 2)
398
399
def logistic(classifier, testArray, t):
    """Predict each test row with the fitted logistic model and report accuracies."""
    predicted = []
    for row in testArray:
        vector = row[:-1]
        # Rows with no active feature get a sentinel label (-1) that can
        # never match the truth column.
        label = int(classifier.predict([vector])[0]) if 1 in vector else -1
        predicted.append(label)
    totalAccuray(testArray, predicted, t)
    featAccuracy(testArray, predicted, t, 1)
    featAccuracy(testArray, predicted, t, 2)
413
414
415
def totalAccuray(testArray, testpredictionarray, t):
    """Print overall accuracy: the fraction of rows whose prediction equals
    the truth label stored in the row's final column."""
    correct = sum(
        1
        for row, predicted in zip(testArray, testpredictionarray)
        if int(row[len(row) - 1]) == predicted
    )
    accuracy = correct / len(testArray)
    print(t + " set accuracy = " + str(accuracy))
425
426
def featAccuracy(testArray, testpredictionarray, t, y):
    """Print accuracy restricted to rows having exactly `y` cells equal to 1
    (the label column counts toward that total); UNDEFINED if no such rows."""
    eligible = 0
    correct = 0
    for position, row in enumerate(testArray):
        if list(row).count(1) != y:
            continue
        eligible += 1
        truth = int(row[len(row) - 1])
        if truth == testpredictionarray[position]:
            correct += 1
    if eligible == 0:
        # No row qualifies: accuracy is undefined rather than an error.
        print(t + " set accuracy for only " + str(y) + " feature vectors = UNDEFINED")
        return
    print(t + " set accuracy for only " + str(y) + " feature vectors = " + str(correct / eligible))
445
446
447
448
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main(sys.argv)