--- /dev/null
+++ b/src/api.py
@@ -0,0 +1,165 @@
+import argparse
+parser = argparse.ArgumentParser(description='Backend service for the annotation frontend: receives sentences and predicts entities.')
+
+parser.add_argument('-l', '--length', type=int, default=128,
+                    help='Maximum sequence length of the model\'s input layer.')
+parser.add_argument('-m', '--model', type=str, default='../models/medcondbert.pth',
+                    help='Path to the model checkpoint to be used for prediction.')
+parser.add_argument('-tr', '--transfer_learning', action='store_true',
+                    help='Set this flag if the given model has been trained on top of BioBERT. \
+                    Careful: It will not work if wrongly specified!')
+parser.add_argument('-p', '--port', type=int, default=5000,
+                    help='The port on which the service is going to run.')
+parser.add_argument('-t', '--type', type=str, required=True,
+                    help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
+
+args = parser.parse_args()
+
+max_length = args.length
+model_path = args.model
+transfer_learning = args.transfer_learning
+port = args.port
+
+print("Preparing model...")
+
+from gevent import monkey; monkey.patch_all()  # patch the standard library before the other imports
+from gevent.pywsgi import WSGIServer
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+from utils.dataloader import Dataloader
+from utils.BertArchitecture import BertNER, BioBertNER
+from utils.metric_tracking import MetricsTracking
+import torch
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from transformers import BertTokenizer, BertForTokenClassification
+import spacy
+
+# initializing backend
+if not args.transfer_learning:
+    print("Loading base BERT model...")
+    model = BertNER(3)  # O, B-, I- -> 3 labels
+
+    if args.type == 'Medical Condition':
+        type = 'MEDCOND'
+    elif args.type == 'Symptom':
+        type = 'SYMPTOM'
+    elif args.type == 'Medication':
+        type = 'MEDICATION'
+    elif args.type == 'Vital Statistic':
+        type = 'VITALSTAT'
+    elif args.type == 'Measurement Value':
+        type = 'MEASVAL'
+    elif args.type == 'Negation Cue':
+        type = 'NEGATION'
+    elif args.type == 'Medical Procedure':
+        type = 'PROCEDURE'
+    else:
+        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
+
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    tokenizer.add_tokens(['B-' + args.type, 'I-' + args.type])
+else:
+    print("Loading BERT model based on BioBERT diseases...")
+
+    if not args.type == 'Medical Condition':
+        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')
+
+    model = BioBertNER(3)  # O, B-, I- -> 3 labels
+    tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
+    type = 'DISEASE'
+
+label_to_ids = {
+    'B-' + type: 0,
+    'I-' + type: 1,
+    'O': 2
+    }
+
+ids_to_label = {
+    0: 'B-' + type,
+    1: 'I-' + type,
+    2: 'O'
+    }
+
+# load the trained weights onto CPU (the service runs inference on CPU)
+model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+model.eval()
+
+app = Flask(__name__)
+CORS(app)  # initialize CORS
+
+sentence_detector = spacy.load("en_core_web_sm")
+
+print("Serving API now...")
+
+def predict_sentence(sentence):
+    t_sen = tokenizer.tokenize(sentence)
+
+    sen_code = tokenizer.encode_plus(sentence,
+                                     return_tensors='pt',
+                                     add_special_tokens=True,
+                                     max_length=max_length,
+                                     padding='max_length',
+                                     return_attention_mask=True,
+                                     truncation=True
+                                     )
+    inputs = {key: torch.as_tensor(val) for key, val in sen_code.items()}
+
+    attention_mask = inputs['attention_mask'].squeeze(1)
+    input_ids = inputs['input_ids'].squeeze(1)
+
+    with torch.no_grad():  # inference only, no gradients needed
+        outputs = model(input_ids, attention_mask)
+
+    predictions = outputs.logits.argmax(dim=-1)
+    predictions = [ids_to_label.get(x) for x in predictions.numpy()[0]]
+
+    # drop the predictions made for the special tokens [CLS] and [SEP]
+    cutoff = min(len(predictions) - 2, len(t_sen))
+    predictions = predictions[1:cutoff + 1]
+    t_sen = t_sen[:cutoff]
+
+    return t_sen, predictions
+
+def clean(tokens, labels):
+    cleaned_tokens = []
+    cleaned_labels = []
+
+    for i in range(len(tokens)):  # tokens and labels have the same length
+        if tokens[i].startswith("##") and len(cleaned_tokens) > 0:
+            # merge WordPiece continuations back into the preceding token
+            cleaned_tokens[-1] = cleaned_tokens[-1] + tokens[i][2:]
+        else:
+            cleaned_tokens.append(tokens[i])
+            cleaned_labels.append(labels[i])
+
+    return cleaned_tokens, cleaned_labels
+
+
+def handle_request(data):
+    sentences = sentence_detector(data).sents
+
+    tokens = []
+    labels = []
+
+    for sentence in sentences:
+        new_tokens, new_labels = predict_sentence(sentence.text)
+        tokens = tokens + new_tokens
+        labels = labels + new_labels
+
+    cleaned_tokens, cleaned_labels = clean(tokens, labels)
+
+    return cleaned_tokens, cleaned_labels
+
+@app.route('/extract_entities', methods=['POST'])
+def main():
+    text = request.get_data(as_text=True)
+    result = handle_request(text)
+    return jsonify({'tokens': result[0], 'entities': result[1]})
+
+if __name__ == '__main__':
+    LISTEN = ('0.0.0.0', port)
+    http_server = WSGIServer(LISTEN, app)
+    http_server.serve_forever()
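
For reference, a minimal client-side sketch of how the endpoint can be exercised once the service is running. Assumptions: the service was started locally with the default port 5000 (e.g. python src/api.py -t 'Medical Condition'), the requests library is available on the client, and the example sentence is invented.

import requests

# POST raw text; the service responds with parallel lists of tokens and entity labels
text = "Patient was diagnosed with diabetes mellitus and reports severe headaches."
response = requests.post('http://localhost:5000/extract_entities', data=text.encode('utf-8'))
result = response.json()

for token, label in zip(result['tokens'], result['entities']):
    print(token, label)

The response keys 'tokens' and 'entities' correspond to the cleaned token list and the BIO labels produced by handle_request.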