[0eda78]: / src / predict_single_sentence.py

Download this file

118 lines (95 with data), 4.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# --- Command-line interface ---------------------------------------------
# Parses the max sequence length, checkpoint path, transfer-learning flag,
# annotation type, and the sentence to tag.
import argparse

parser = argparse.ArgumentParser(description='Enter a sentence for the model to work on.')
parser.add_argument('-l', '--length', type=int, default=128,
                    help='Choose the maximum length of the model\'s input layer.')
parser.add_argument('-m', '--model', type=str, default='../models/medcondbert.pth',
                    help='Choose the directory of the model to be used for prediction.')
# FIX: argparse's type=bool is broken for flags -- bool("False") is True,
# so ANY non-empty value (including "-tr False") enabled transfer learning.
# A store_true flag is the correct idiom: pass -tr to enable, omit to disable.
parser.add_argument('-tr', '--transfer_learning', action='store_true', default=False,
                    help='Choose whether the given model has been trained on BioBERT or not. \
Careful: It will not work if wrongly specified!')
parser.add_argument('sentence', type=str,
                    help='Write your input sentence, preferrably an admission note!')
parser.add_argument('-t', '--type', type=str, required=True,
                    help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
args = parser.parse_args()

max_length = args.length  # maximum token length fed to the model
model_path = args.model   # path of the saved state_dict to load
sentence = args.sentence  # raw text to run NER on
from utils.dataloader import Dataloader
from utils.BertArchitecture import BertNER, BioBertNER
from utils.metric_tracking import MetricsTracking
import torch
from torch.optim import SGD
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer,BertForTokenClassification
# Resolve which model/tokenizer pair to use and the internal tag prefix
# for the requested annotation type.
if args.transfer_learning:
    print("Training BERT model based on BioBERT diseases...")
    # The BioBERT baseline was fine-tuned on disease mentions only.
    if args.type != 'Medical Condition':
        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')
    model = BioBertNER(3)  # O, B-, I- -> 3 entities
    tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
    type = 'DISEASE'
else:
    print("Training base BERT model...")
    model = BertNER(3)  # O, B-, I- -> 3 entities
    # Human-readable annotation name -> internal tag abbreviation.
    annotation_tags = {
        'Medical Condition': 'MEDCOND',
        'Symptom': 'SYMPTOM',
        'Medication': 'MEDICATION',
        'Vital Statistic': 'VITALSTAT',
        'Measurement Value': 'MEASVAL',
        'Negation Cue': 'NEGATION',
        'Medical Procedure': 'PROCEDURE',
    }
    if args.type not in annotation_tags:
        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
    type = annotation_tags[args.type]
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # NOTE(review): tokens are registered under the human-readable name
    # (e.g. 'B-Medical Condition'), not the abbreviation -- kept as-is.
    tokenizer.add_tokens(['B-' + args.type, 'I-' + args.type])
# BIO tag strings <-> class ids, derived from one canonical ordering so the
# two lookup tables can never drift apart.
_tags = ('B-' + type, 'I-' + type, 'O')
label_to_ids = {tag: idx for idx, tag in enumerate(_tags)}
ids_to_label = {idx: tag for idx, tag in enumerate(_tags)}
# --- Inference ----------------------------------------------------------
# FIX: map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
# machines; the forward pass below runs on CPU tensors anyway (the model
# is never moved to a device in this script).
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# Tokenize once for display alignment (word pieces), and once via
# encode_plus for the padded/truncated model input.
t_sen = tokenizer.tokenize(sentence)
sen_code = tokenizer.encode_plus(sentence,
    return_tensors='pt',
    add_special_tokens=True,
    max_length=max_length,
    padding='max_length',
    return_attention_mask=True,
    truncation=True
)

inputs = {key: torch.as_tensor(val) for key, val in sen_code.items()}
attention_mask = inputs['attention_mask'].squeeze(1)
input_ids = inputs['input_ids'].squeeze(1)

# Inference only: disable autograd to avoid building a graph.
with torch.no_grad():
    outputs = model(input_ids, attention_mask)

# Per-token class ids -> BIO tag strings for the single (batch=1) sentence.
predictions = outputs.logits.argmax(dim=-1)
predictions = [ids_to_label.get(x) for x in predictions.numpy()[0]]

# Beware special tokens: drop the leading [CLS] and anything past the last
# real word piece ([SEP]/padding) so tags line up 1:1 with t_sen.
cutoff = min(len(predictions) - 1, len(t_sen))
predictions = predictions[1:cutoff + 1]
t_sen = t_sen[:cutoff]

df = pd.DataFrame({
    'Sentence': t_sen,
    'Prediction': predictions
})

# Show the full table regardless of terminal/pandas display limits.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', None)
print(df)