|
a |
|
b/predict.py |
|
|
1 |
|
|
|
2 |
from transformers import (AutoModelForTokenClassification, |
|
|
3 |
AutoModelForSequenceClassification, |
|
|
4 |
TrainingArguments, |
|
|
5 |
AutoTokenizer, |
|
|
6 |
AutoConfig, |
|
|
7 |
Trainer) |
|
|
8 |
|
|
|
9 |
from biobert_ner.utils_ner import (convert_examples_to_features, get_labels, NerTestDataset) |
|
|
10 |
from biobert_ner.utils_ner import InputExample as NerExample |
|
|
11 |
|
|
|
12 |
from biobert_re.utils_re import RETestDataset |
|
|
13 |
|
|
|
14 |
from bilstm_crf_ner.model.config import Config as BiLSTMConfig |
|
|
15 |
from bilstm_crf_ner.model.ner_model import NERModel as BiLSTMModel |
|
|
16 |
from bilstm_crf_ner.model.ner_learner import NERLearner as BiLSTMLearner |
|
|
17 |
import en_ner_bc5cdr_md |
|
|
18 |
|
|
|
19 |
import numpy as np |
|
|
20 |
import os |
|
|
21 |
from torch import nn |
|
|
22 |
from ehr import HealthRecord |
|
|
23 |
from generate_data import scispacy_plus_tokenizer |
|
|
24 |
from annotations import Entity |
|
|
25 |
import logging |
|
|
26 |
|
|
|
27 |
from typing import List, Tuple |
|
|
28 |
|
|
|
29 |
# --- Logging ---------------------------------------------------------------
# Module logger, plus switching off the matplotlib font-manager logger.
logger = logging.getLogger(__name__)
logging.getLogger('matplotlib.font_manager').disabled = True

# --- Sequence lengths ------------------------------------------------------
# Passed as max_len / max_seq_length when a record is split and encoded for
# the corresponding model.
BIOBERT_NER_SEQ_LEN = 128
BIOBERT_RE_SEQ_LEN = 128
BILSTM_NER_SEQ_LEN = 512

# --- Model directories -----------------------------------------------------
# Each directory is expected to contain "config.json" and "pytorch_model.bin",
# which are the two files loaded from it further down in this module.
BIOBERT_NER_MODEL_DIR = "biobert_ner/output_full"
BIOBERT_RE_MODEL_DIR = "biobert_re/output_full"
|
|
38 |
|
|
|
39 |
# =====BioBERT Model for NER======
# Everything in this section runs at import time: the label list, config,
# tokenizer and model weights are all loaded as soon as the module is imported.

# Label vocabulary returned by get_labels; a label's index in this list is
# used as its class id below.
biobert_ner_labels = get_labels('biobert_ner/dataset_full/labels.txt')
# id -> label lookup; align_predictions uses it to turn argmax class ids back
# into label strings.
biobert_ner_label_map = {i: label for i, label in enumerate(biobert_ner_labels)}
num_labels_ner = len(biobert_ner_labels)

# Config is read from the config.json stored under BIOBERT_NER_MODEL_DIR and
# given both directions of the label mapping.
biobert_ner_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "config.json"),
    num_labels=num_labels_ner,
    id2label=biobert_ner_label_map,
    label2id={label: i for i, label in enumerate(biobert_ner_labels)})

# The tokenizer is loaded by model name, not from BIOBERT_NER_MODEL_DIR.
# It is also reused for relation extraction in get_re_predictions.
biobert_ner_tokenizer = AutoTokenizer.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1")

# Weights come from the pytorch_model.bin file, combined with the config above.
biobert_ner_model = AutoModelForTokenClassification.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_ner_config)

# In this file the Trainer is only ever used through .predict().
biobert_ner_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)

biobert_ner_trainer = Trainer(model=biobert_ner_model, args=biobert_ner_training_args)

# Short entity tag (as it appears after the "B-"/"I-" prefix of a label)
# -> entity name given to the Entity objects built in get_ner_predictions.
label_ent_map = {'DRUG': 'Drug', 'STR': 'Strength',
                 'DUR': 'Duration', 'ROU': 'Route',
                 'FOR': 'Form', 'ADE': 'ADE',
                 'DOS': 'Dosage', 'REA': 'Reason',
                 'FRE': 'Frequency'}
|
|
66 |
|
|
|
67 |
# =====BiLSTM + CRF model for NER=========
# Built at import time: config -> model -> learner, then saved state is loaded.
bilstm_config = BiLSTMConfig()
bilstm_model = BiLSTMModel(bilstm_config)
bilstm_learn = BiLSTMLearner(bilstm_config, bilstm_model)
# NOTE(review): how this name is resolved to a file is decided inside
# NERLearner.load, which is not visible here - confirm the expected location.
bilstm_learn.load("ner_15e_bilstm_crf_elmo")

# Only the tokenizer of the en_ner_bc5cdr_md pipeline is kept.
scispacy_tok = en_ner_bc5cdr_md.load().tokenizer
# Rebinds the function's default-argument tuple to a single value, so the
# last positional parameter of scispacy_plus_tokenizer defaults to the
# tokenizer loaded above (presumably its tokenizer argument - verify against
# generate_data). This lets the function be passed around as a one-argument
# tokenizer, as get_ner_predictions does.
scispacy_plus_tokenizer.__defaults__ = (scispacy_tok,)
|
|
75 |
|
|
|
76 |
# =====BioBERT Model for RE======
# Loaded at import time, mirroring the NER section above.

# Binary relation labels; get_re_predictions keeps a candidate relation when
# the predicted class index is 1.
re_label_list = ["0", "1"]
re_task_name = "ehr-re"

# Config is read from the config.json stored under BIOBERT_RE_MODEL_DIR.
biobert_re_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "config.json"),
    num_labels=len(re_label_list),
    finetuning_task=re_task_name)

# Weights come from the pytorch_model.bin file, combined with the config above.
biobert_re_model = AutoModelForSequenceClassification.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_re_config,)

# In this file the Trainer is only ever used through .predict().
biobert_re_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)

biobert_re_trainer = Trainer(model=biobert_re_model, args=biobert_re_training_args)
|
|
92 |
|
|
|
93 |
|
|
|
94 |
def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> List[List[str]]:
    """
    Get the list of labelled predictions from model output

    Parameters
    ----------
    predictions : np.ndarray
        An array of shape (num_examples, seq_len, num_labels).

    label_ids : np.ndarray
        An array of shape (num_examples, seq_length).
        Positions equal to ``nn.CrossEntropyLoss().ignore_index``
        (-100 by default) are skipped.

    Returns
    -------
    preds_list : List[List[str]]
        One list per example holding the label string (looked up in
        ``biobert_ner_label_map``) of every position that is not ignored.

    """
    # Read the sentinel once; the previous version built a new loss module
    # for every single token just to read this attribute.
    ignore_index = nn.CrossEntropyLoss().ignore_index

    # Highest-scoring class id for every position.
    preds = np.argmax(predictions, axis=2)

    return [
        [biobert_ner_label_map[pred]
         for pred, label_id in zip(pred_row, label_row)
         if label_id != ignore_index]
        for pred_row, label_row in zip(preds, label_ids)
    ]
|
|
123 |
|
|
|
124 |
|
|
|
125 |
def get_chunk_type(tok: str) -> Tuple[str, str]:
    """
    Split an IOB label into its prefix and its entity type.

    Args:
        tok: Label in IOB format, e.g. "B-DRUG"

    Returns:
        tuple: (text before the first "-", text after the last "-"),
            e.g. ("B", "DRUG"). A label without "-" is returned in both
            positions, e.g. "O" -> ("O", "O").

    """
    pieces = tok.split('-')
    return pieces[0], pieces[-1]
|
|
138 |
|
|
|
139 |
|
|
|
140 |
def get_chunks(seq: List[str]) -> List[Tuple[str, int, int]]:
    """
    Given a sequence of tags, group entities and their position

    Args:
        seq: ["O", "O", "B-DRUG", "I-DRUG", ...] sequence of labels

    Returns:
        list of (chunk_type, chunk_start, chunk_end), where both
        chunk_start and chunk_end are token indices into ``seq`` and
        chunk_end is inclusive (the index of the chunk's last token).

    Example:
        seq = ["B-DRUG", "I-DRUG", "O", "B-STR"]
        result = [("DRUG", 0, 1), ("STR", 3, 3)]

    """
    default = "O"
    chunks = []
    chunk_type, chunk_start = None, None

    for i, tok in enumerate(seq):
        if tok == default:
            # An "O" tag closes whatever chunk is currently open.
            if chunk_type is not None:
                chunks.append((chunk_type, chunk_start, i - 1))
                chunk_type, chunk_start = None, None
            continue

        # Same parsing as get_chunk_type: prefix before the first "-",
        # entity type after the last "-".
        tag_parts = tok.split('-')
        tok_chunk_class, tok_chunk_type = tag_parts[0], tag_parts[-1]

        if chunk_type is None:
            # First token of a new chunk.
            chunk_type, chunk_start = tok_chunk_type, i
        elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
            # A different type or an explicit "B" ends the open chunk and
            # immediately starts the next one.
            chunks.append((chunk_type, chunk_start, i - 1))
            chunk_type, chunk_start = tok_chunk_type, i
        # Otherwise the token continues the open chunk; nothing to record.

    # A chunk still open at the end of the sequence finishes on the last
    # token. The end index is inclusive everywhere else, so it must be
    # len(seq) - 1 here, not len(seq).
    if chunk_type is not None:
        chunks.append((chunk_type, chunk_start, len(seq) - 1))

    return chunks
|
|
185 |
|
|
|
186 |
|
|
|
187 |
# noinspection PyTypeChecker |
|
|
188 |
def get_biobert_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Run the BioBERT NER model over one EHR record.

    The record's tokens are cut into segments at the record's split points,
    each segment is encoded and scored by the model, and the per-token labels
    are grouped into entities with character offsets.

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx).

    """
    # Two positions fewer than the model sequence length are allowed per
    # segment (presumably room for the CLS/SEP tokens - TODO confirm).
    boundaries = test_ehr.get_split_points(max_len=BIOBERT_NER_SEQ_LEN - 2)

    # One example per pair of consecutive split points; the labels are dummy
    # "O" placeholders since this is inference only.
    examples = []
    for pos in range(len(boundaries) - 1):
        segment = test_ehr.tokens[boundaries[pos]:boundaries[pos + 1]]
        examples.append(NerExample(guid=str(boundaries[pos]),
                                   words=segment,
                                   labels=["O"] * len(segment)))

    tokenizer = biobert_ner_tokenizer
    input_features = convert_examples_to_features(
        examples,
        biobert_ner_labels,
        max_seq_length=BIOBERT_NER_SEQ_LEN,
        tokenizer=tokenizer,
        cls_token_at_end=False,
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=False,
        pad_on_left=bool(tokenizer.padding_side == "left"),
        pad_token=tokenizer.pad_token_id,
        pad_token_segment_id=tokenizer.pad_token_type_id,
        pad_token_label_id=nn.CrossEntropyLoss().ignore_index,
        verbose=0)

    raw_scores, label_ids, _ = biobert_ner_trainer.predict(NerTestDataset(input_features))

    # Concatenate the per-example label lists into one flat list.
    flat_labels = []
    for example_labels in align_predictions(raw_scores, label_ids):
        flat_labels.extend(example_labels)

    # Walk the record's tokens. Tokens that do not start with "##" consume
    # the next predicted label; "##" tokens reuse the previous token's label
    # ("O" stays "O", anything else becomes "I-<type>").
    final_predictions = []
    last_label = ""
    cursor = 0
    for token in test_ehr.get_tokens():
        if not token.startswith("##"):
            last_label = flat_labels[cursor]
            final_predictions.append(last_label)
            cursor += 1
        elif last_label == "O":
            final_predictions.append(last_label)
        else:
            final_predictions.append("I-" + last_label.split("-")[-1])

    # Convert token-index chunks to character offsets: start offset of the
    # first token, end offset of the last token.
    return [(chunk[0],
             test_ehr.get_char_idx(chunk[1])[0],
             test_ehr.get_char_idx(chunk[2])[1])
            for chunk in get_chunks(final_predictions)]
|
|
262 |
|
|
|
263 |
|
|
|
264 |
def get_bilstm_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Run the BiLSTM NER model over one EHR record.

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx).

    """
    boundaries = test_ehr.get_split_points(max_len=BILSTM_NER_SEQ_LEN)
    n_segments = len(boundaries) - 1

    # One token list per pair of consecutive split points.
    examples = [test_ehr.tokens[boundaries[k]:boundaries[k + 1]]
                for k in range(n_segments)]

    predictions = bilstm_learn.predict(examples)

    pred_entities = []
    for k in range(n_segments):
        # Chunk indices are relative to the segment; adding the segment's
        # first split point makes them indices into the whole record.
        offset = boundaries[k]
        for chunk in get_chunks(predictions[k]):
            first_char = test_ehr.get_char_idx(offset + chunk[1])[0]
            last_char = test_ehr.get_char_idx(offset + chunk[2])[1]
            pred_entities.append((chunk[0], first_char, last_char))

    return pred_entities
|
|
298 |
|
|
|
299 |
|
|
|
300 |
# noinspection PyTypeChecker |
|
|
301 |
def get_ner_predictions(ehr_record: str, model_name: str = "biobert", record_id: str = "1") -> HealthRecord:
    """
    Get predictions for NER using either BioBERT or BiLSTM

    Parameters
    --------------
    ehr_record : str
        An EHR record in text format.

    model_name : str
        The model to use for prediction. Default is biobert.
        Matched case-insensitively against "biobert" and "bilstm".

    record_id : str
        The record id of the returned object. Default is 1.

    Returns
    -----------
    A HealthRecord object with entities set.

    Raises
    -----------
    AttributeError
        If model_name is neither "biobert" nor "bilstm".
    """
    model_key = model_name.lower()

    if model_key == "biobert":
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=biobert_ner_tokenizer.tokenize,
                                is_bert_tokenizer=True,
                                is_training=False)

        predictions = get_biobert_ner_predictions(test_ehr)

    elif model_key == "bilstm":
        # record_id is forwarded here as well; it used to be dropped for this
        # model, so the caller's id was silently replaced by the default.
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=scispacy_plus_tokenizer,
                                is_bert_tokenizer=False,
                                is_training=False)

        predictions = get_bilstm_ner_predictions(test_ehr)

    else:
        raise AttributeError("Accepted model names include 'biobert' "
                             "and 'bilstm'.")

    ent_preds = []
    for i, pred in enumerate(predictions):
        # pred is (short tag, start char, end char); the tag is mapped to the
        # entity name through label_ent_map.
        ent = Entity("T%d" % i, label_ent_map[pred[0]], [pred[1], pred[2]])
        ent_text = test_ehr.text[ent[0]:ent[1]]

        # Drop spans that contain no letter or digit at all.
        if not any(letter.isalnum() for letter in ent_text):
            continue

        ent.set_text(ent_text)
        ent_preds.append(ent)

    test_ehr.entities = ent_preds
    return test_ehr
|
|
353 |
|
|
|
354 |
|
|
|
355 |
def get_re_predictions(test_ehr: HealthRecord) -> HealthRecord:
    """
    Predict relations for a record whose entities are already set.

    Parameters
    -----------
    test_ehr : HealthRecord
        A HealthRecord object with entities set.

    Returns
    --------
    HealthRecord
        The same object, with ``relations`` set to the candidate relations
        the model classified as class 1 (an empty list when the dataset
        built from the record has no items).
    """
    # The dataset is built with the NER tokenizer object and the RE sequence
    # length.
    candidates = RETestDataset(test_ehr, biobert_ner_tokenizer,
                               BIOBERT_RE_SEQ_LEN, re_label_list)

    # Nothing to classify: skip the model call entirely.
    if len(candidates) == 0:
        test_ehr.relations = []
        return test_ehr

    scores = biobert_re_trainer.predict(test_dataset=candidates).predictions
    predicted_classes = np.argmax(scores, axis=1)

    # Keep only the relations predicted as class 1, in dataset order.
    accepted = [relation
                for relation, predicted in zip(candidates.relation_list, predicted_classes)
                if predicted == 1]

    # Kept relations are numbered consecutively from R1.
    for number, relation in enumerate(accepted, start=1):
        relation.ann_id = "R%d" % number

    test_ehr.relations = accepted
    return test_ehr