Diff of /allennlp/sentiment.py [000000] .. [2d4573]

Switch to unified view

a b/allennlp/sentiment.py
1
import sys
2
import os
3
import torch
4
import re
5
import loader
6
# from allennlp.models.archival import *
7
from allennlp.data import DatasetReader
8
from allennlp.common.params import Params
9
from allennlp.predictors.text_classifier import TextClassifierPredictor
10
import time
11
12
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
from ehrkit import ehrkit
14
# from config import USERNAME, PASSWORD
15
16
17
def load_glove():
    """Load the GloVe sentiment predictor, downloading it if not cached.

    Returns:
        The predictor object, either deserialized from the cached file
        sitting next to this module or freshly fetched via ``loader``.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    glove_path = os.path.join(here, "glove_sentiment_predictor.txt")  # same dir for github
    if not os.path.exists(glove_path):
        print('Downloading Glove Sentiment Analysis Model')
        return loader.download_glove()
    print('Loading Glove Sentiment Analysis Model')
    return torch.load(glove_path)
27
28
29
def load_roberta():
    """Load the RoBERTa sentiment predictor, downloading it if not cached.

    Returns:
        A ``TextClassifierPredictor`` built from the serialized model and
        the dataset reader described by the cached AllenNLP config, or the
        predictor returned by ``loader.download_roberta()`` when no local
        copy exists.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    serialization_dir = os.path.join(base_dir, 'roberta', '')
    config_file = os.path.join(serialization_dir, 'config.json')

    if not os.path.exists(config_file):
        print('Downloading Roberta Sentiment Analysis Model')
        return loader.download_roberta()

    print('Loading Roberta Sentiment Analysis Model')
    model = torch.load(os.path.join(serialization_dir, 'whole_model.pt'))
    loaded_params = Params.from_file(config_file)
    dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))

    # Build the predictor directly from model + dataset reader; the
    # alternatives (Model.load with weights.th, or load_archive on
    # model.tar.gz) were found to take far too long.
    return TextClassifierPredictor(model, dataset_reader)
50
51
52
def get_doc():
    """Prompt for a MIMIC note ID (or draw one at random) and fetch its text.

    Returns:
        A ``(doc_id, clean_text)`` tuple, where ``clean_text`` is the
        document lowercased with every run of characters outside
        ``[A-Za-z0-9.,-/]`` collapsed to a single space.

    Exits the program with an error message if the document cannot be
    retrieved. Relies on the module-level ``ehrdb`` session.
    """
    doc_id = input("MIMIC Document ID [press Enter for random]: ")
    if doc_id == '':
        # No ID supplied: sample a random row from NOTEEVENTS.
        ehrdb.cur.execute("SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1")
        doc_id = ehrdb.cur.fetchall()[0][0]
        print('Document ID: %s' % doc_id)
    try:
        text = ehrdb.get_document(int(doc_id))
        # Raw string avoids invalid-escape warnings in the regex pattern.
        clean_text = re.sub(r'[^A-Za-z0-9\.\,\-\/]+', ' ', text).lower()
        return doc_id, clean_text
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate. str() is required because a randomly drawn
        # ROW_ID comes back from the cursor as an int, and concatenating
        # it directly would raise a TypeError here.
        message = 'Error: There is no document with ID \'' + str(doc_id) + '\' in mimic.NOTEEVENTS'
        sys.exit(message)
65
66
67
if __name__ == '__main__':
    # NOTE(review): hard-coded credentials should move to a config module
    # (see the commented-out USERNAME/PASSWORD import at the top of the file).
    # ehrdb = ehrkit.start_session(USERNAME, PASSWORD)
    ehrdb = ehrkit.start_session("jeremy.goldwasser@localhost", "mysql4710")
    doc_id, clean_text = get_doc()
    # print('LENGTH OF DOCUMENT: %d' % len(clean_text))

    x = input('GloVe or RoBERTa predictor [g=GloVe, r=RoBERTa]? ')
    if x == 'g':
        glove_predictor = load_glove()
        probs = glove_predictor.predict(clean_text)['probs']
    elif x == 'r':
        roberta_predictor = load_roberta()
        try:
            probs = roberta_predictor.predict(clean_text)['probs']
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit are not
            # swallowed; long documents overflow RoBERTa's input length, in
            # which case we fall back to the GloVe predictor.
            print('Document too long for RoBERTa model. Using GLoVe instead.')
            glove_predictor = load_glove()
            probs = glove_predictor.predict(clean_text)['probs']
    else:
        sys.exit('Error: Must input \'g\' or  \'r\'')

    # probs[0] is treated as the probability of the positive class.
    classification = 'Positive' if probs[0] >= 0.5 else 'Negative'
    print('Sentiment of document: %s' % classification)

    # # Save sentiment as json file
    # sentiment = {'text': clean_text, 'sentiment': classification, 'prob': probs[0]}
    # with open('predicted_sentiments/' + str(doc_id) + '.json', 'w', encoding='utf-8') as f:
    #     json.dump(sentiment, f, ensure_ascii=False, indent=4)
97