Diff of /src/api.py [000000] .. [0eda78]

import argparse
parser = argparse.ArgumentParser(description='Backend for the accompanying frontend. The service receives sentences and predicts entities.')

parser.add_argument('-l', '--length', type=int, default=128,
                    help='Maximum input sequence length of the model.')
parser.add_argument('-m', '--model', type=str, default='../models/medcondbert.pth',
                    help='Path of the model to be used for prediction.')
parser.add_argument('-tr', '--transfer_learning', action='store_true',
                    help='Set this flag if the given model has been trained on BioBERT. \
                    Careful: It will not work if wrongly specified!')
parser.add_argument('-p', '--port', type=int, default=5000,
                    help='The port on which the service is going to run.')
parser.add_argument('-t', '--type', type=str, required=True,
                    help='Type of annotation to process. Needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')

args = parser.parse_args()

max_length = args.length
model_path = args.model
transfer_learning = args.transfer_learning
port = args.port
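
# Example invocation (hypothetical paths and values, adjust to your setup):
#   python api.py --type "Medical Condition" --model ../models/medcondbert.pth --port 5000
# Pass --transfer_learning only if the checkpoint was fine-tuned on top of BioBERT.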

print("Preparing model...")

from gevent.pywsgi import WSGIServer # Imports the WSGIServer
from gevent import monkey; monkey.patch_all()
from flask import Flask, request, jsonify
from flask_cors import CORS
from utils.dataloader import Dataloader
from utils.BertArchitecture import BertNER, BioBertNER
from utils.metric_tracking import MetricsTracking
import torch
from torch.optim import SGD
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer, BertForTokenClassification
import spacy

# initializing backend
if not args.transfer_learning:
    print("Loading base BERT model...")
    model = BertNER(3) #O, B-, I- -> 3 entities

    # map the human-readable annotation type to its label abbreviation
    annotation_types = {
        'Medical Condition': 'MEDCOND',
        'Symptom': 'SYMPTOM',
        'Medication': 'MEDICATION',
        'Vital Statistic': 'VITALSTAT',
        'Measurement Value': 'MEASVAL',
        'Negation Cue': 'NEGATION',
        'Medical Procedure': 'PROCEDURE'
    }
    if args.type not in annotation_types:
        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
    type = annotation_types[args.type]

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_tokens(['B-' + args.type, 'I-' + args.type])
else:
    print("Loading BERT model based on BioBERT diseases...")

    if not args.type == 'Medical Condition':
        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')

    model = BioBertNER(3) #O, B-, I- -> 3 entities
    tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
    type = 'DISEASE'

label_to_ids = {
    'B-' + type: 0,
    'I-' + type: 1,
    'O': 2
    }

ids_to_label = {
    0: 'B-' + type,
    1: 'I-' + type,
    2: 'O'
    }

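# For example, with --type "Medical Condition" these mappings become
#   label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
#   ids_to_label = {0: 'B-MEDCOND', 1: 'I-MEDCOND', 2: 'O'}
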
# load the trained weights (on CPU, so the service also runs on machines without a GPU)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

app = Flask(__name__)
CORS(app)  # Initialize CORS

sentence_detector = spacy.load("en_core_web_sm")

print("Serving API now...")

def predict_sentence(sentence):
    t_sen = tokenizer.tokenize(sentence)

    sen_code = tokenizer.encode_plus(sentence,
        return_tensors='pt',
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        return_attention_mask=True,
        truncation=True
        )
    inputs = {key: torch.as_tensor(val) for key, val in sen_code.items()}

    attention_mask = inputs['attention_mask'].squeeze(1)
    input_ids = inputs['input_ids'].squeeze(1)

    # inference only, no gradients needed
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    predictions = outputs.logits.argmax(dim=-1)
    predictions = [ids_to_label.get(x) for x in predictions.numpy()[0]]

    # drop the predictions for special tokens ([CLS], [SEP]) and padding
    cutoff = min(len(predictions)-1, len(t_sen))
    predictions = predictions[1:cutoff+1]
    t_sen = t_sen[:cutoff]

    return t_sen, predictions

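# Example call (hypothetical output, assuming a model trained for Medical Condition):
#   predict_sentence("Patient suffers from diabetes.")
#   -> (['patient', 'suffers', 'from', 'diabetes', '.'],
#       ['O', 'O', 'O', 'B-MEDCOND', 'O'])
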
def clean(tokens, labels):
    """Merge WordPiece sub-tokens (starting with '##') back into whole words,
    keeping the label of the first piece of each word."""
    cleaned_tokens = []
    cleaned_labels = []

    for i in range(len(tokens)): #same length
        if tokens[i].startswith("##") and len(cleaned_tokens) > 0:
            # continuation piece: glue it onto the previously emitted token
            cleaned_tokens[-1] = cleaned_tokens[-1] + tokens[i][2:]
        else:
            cleaned_tokens.append(tokens[i])
            cleaned_labels.append(labels[i])

    return cleaned_tokens, cleaned_labels

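# Example (hypothetical labels): clean(['head', '##ache'], ['B-MEDCOND', 'I-MEDCOND'])
#   -> (['headache'], ['B-MEDCOND'])
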
def handle_request(data):
    sentences = sentence_detector(data).sents

    tokens = []
    labels = []

    for sentence in sentences:
        new_tokens, new_labels = predict_sentence(sentence.text)
        tokens = tokens + new_tokens
        labels = labels + new_labels

    cleaned_tokens, cleaned_labels = clean(tokens, labels)

    return cleaned_tokens, cleaned_labels

@app.route('/extract_entities', methods=['POST'])
def main():
    text = request.get_data(as_text=True)
    result = handle_request(text)
    return jsonify({'tokens': result[0], 'entities': result[1]})

if __name__ == '__main__':
    LISTEN = ('0.0.0.0', port)
    http_server = WSGIServer(LISTEN, app)
    http_server.serve_forever()
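
# Example request against the running service (hypothetical input and output):
#   curl -X POST http://localhost:5000/extract_entities -d "Patient suffers from diabetes."
#   -> {"tokens": ["patient", "suffers", "from", "diabetes", "."],
#       "entities": ["O", "O", "O", "B-MEDCOND", "O"]}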