a b/predict.py
1
2
from transformers import (AutoModelForTokenClassification,
3
                          AutoModelForSequenceClassification,
4
                          TrainingArguments,
5
                          AutoTokenizer,
6
                          AutoConfig,
7
                          Trainer)
8
9
from biobert_ner.utils_ner import (convert_examples_to_features, get_labels, NerTestDataset)
10
from biobert_ner.utils_ner import InputExample as NerExample
11
12
from biobert_re.utils_re import RETestDataset
13
14
from bilstm_crf_ner.model.config import Config as BiLSTMConfig
15
from bilstm_crf_ner.model.ner_model import NERModel as BiLSTMModel
16
from bilstm_crf_ner.model.ner_learner import NERLearner as BiLSTMLearner
17
import en_ner_bc5cdr_md
18
19
import numpy as np
20
import os
21
from torch import nn
22
from ehr import HealthRecord
23
from generate_data import scispacy_plus_tokenizer
24
from annotations import Entity
25
import logging
26
27
from typing import List, Tuple
28
29
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Disable the matplotlib font-manager logger so its records are not emitted.
logging.getLogger('matplotlib.font_manager').disabled = True

# Sequence lengths used when splitting a record and building model inputs.
BIOBERT_NER_SEQ_LEN = 128
BILSTM_NER_SEQ_LEN = 512
BIOBERT_RE_SEQ_LEN = 128

# Directories that hold the fine-tuned BioBERT config and weights.
BIOBERT_NER_MODEL_DIR = "biobert_ner/output_full"
BIOBERT_RE_MODEL_DIR = "biobert_re/output_full"
38
39
# =====BioBERT Model for NER======
# Label vocabulary read from the dataset's label file; the position of a
# label in this list is its integer id.
biobert_ner_labels = get_labels('biobert_ner/dataset_full/labels.txt')
num_labels_ner = len(biobert_ner_labels)

# id -> label, used to turn argmax indices back into label strings.
biobert_ner_label_map = dict(enumerate(biobert_ner_labels))

# The model config is loaded from the fine-tuned output directory, with the
# label mappings (and their inverse) overridden from the label file.
biobert_ner_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "config.json"),
    num_labels=num_labels_ner,
    id2label=biobert_ner_label_map,
    label2id={label: idx for idx, label in biobert_ner_label_map.items()},
)

# The tokenizer comes from the named pretrained BioBERT checkpoint, not from
# the fine-tuned output directory.
biobert_ner_tokenizer = AutoTokenizer.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1")

# Fine-tuned token-classification weights.
biobert_ner_model = AutoModelForTokenClassification.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_ner_config,
)

# The Trainer is configured with do_predict=True; this module only calls its
# predict() method.
biobert_ner_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)
biobert_ner_trainer = Trainer(model=biobert_ner_model,
                              args=biobert_ner_training_args)
60
61
# Maps the abbreviated entity tag produced by the NER models to the
# entity type name stored on the resulting Entity objects.
label_ent_map = {
    'DRUG': 'Drug',
    'STR': 'Strength',
    'DUR': 'Duration',
    'ROU': 'Route',
    'FOR': 'Form',
    'ADE': 'ADE',
    'DOS': 'Dosage',
    'REA': 'Reason',
    'FRE': 'Frequency',
}
66
67
# =====BiLSTM + CRF model for NER=========
# Build the model and its learner from the default config, then load the
# saved weights by name.
bilstm_config = BiLSTMConfig()
bilstm_model = BiLSTMModel(bilstm_config)
bilstm_learn = BiLSTMLearner(bilstm_config, bilstm_model)
bilstm_learn.load("ner_15e_bilstm_crf_elmo")

# Tokenizer taken from the scispacy BC5CDR pipeline.
scispacy_tok = en_ner_bc5cdr_md.load().tokenizer

# Replace the default-argument tuple of scispacy_plus_tokenizer so that its
# last positional parameter defaults to the tokenizer loaded above.
scispacy_plus_tokenizer.__defaults__ = (scispacy_tok,)
75
76
# =====BioBERT Model for RE======
# Relation extraction is a binary classification task: a prediction of
# index 1 is treated as "relation present" by get_re_predictions.
re_label_list = ["0", "1"]
re_task_name = "ehr-re"

# Config loaded from the fine-tuned output directory.
biobert_re_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "config.json"),
    num_labels=len(re_label_list),
    finetuning_task=re_task_name,
)

# Fine-tuned sequence-classification weights.
biobert_re_model = AutoModelForSequenceClassification.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_re_config,
)

# The Trainer is configured with do_predict=True; this module only calls its
# predict() method.
biobert_re_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)
biobert_re_trainer = Trainer(model=biobert_re_model,
                             args=biobert_re_training_args)
92
93
94
def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> List[List[str]]:
    """
    Get the list of labelled predictions from model output.

    Parameters
    ----------
    predictions : np.ndarray
        An array of shape (num_examples, seq_len, num_labels).

    label_ids : np.ndarray
        An array of shape (num_examples, seq_length).
        Has -100 at positions which need to be ignored.

    Returns
    -------
    preds_list : List[List[str]]
        Labelled output: one list per example, holding the label name
        (looked up in ``biobert_ner_label_map``) of the highest-scoring
        class at every position that is not marked as ignored.

    """
    # Read the ignore index once. Building a CrossEntropyLoss module for
    # every token, as a per-position lookup would, is pure overhead.
    ignore_index = nn.CrossEntropyLoss().ignore_index

    # Highest-scoring label id at every position.
    preds = np.argmax(predictions, axis=2)

    return [
        [biobert_ner_label_map[pred]
         for pred, label_id in zip(pred_row, label_row)
         if label_id != ignore_index]
        for pred_row, label_row in zip(preds, label_ids)
    ]
123
124
125
def get_chunk_type(tok: str) -> Tuple[str, str]:
    """
    Split an IOB label into its tag class and its entity type.

    Args:
        tok: Label in IOB format, e.g. "B-DRUG"

    Returns:
        tuple: (tag_class, tag_type), e.g. ("B", "DRUG"). The class is the
        text before the first "-" and the type is the text after the last
        "-"; a label without "-" (such as "O") yields itself for both.

    """
    pieces = tok.split('-')
    return pieces[0], pieces[-1]
138
139
140
def get_chunks(seq: List[str]) -> List[Tuple[str, int, int]]:
    """
    Given a sequence of tags, group entities and their position

    Args:
        seq: ["O", "O", "B-DRUG", "I-DRUG", ...] sequence of labels

    Returns:
        list of (chunk_type, chunk_start, chunk_end), where both indices
        are token positions in ``seq`` and chunk_end is inclusive.

    Example:
        seq = ["B-DRUG", "I-DRUG", "O", "B-STR"]
        result = [("DRUG", 0, 1), ("STR", 3, 3)]

    """
    default = "O"
    chunks = []
    chunk_type, chunk_start = None, None

    for i, tok in enumerate(seq):
        if tok == default:
            # An "O" tag closes whatever chunk is currently open.
            if chunk_type is not None:
                chunks.append((chunk_type, chunk_start, i - 1))
                chunk_type, chunk_start = None, None
            continue

        # Same parsing as get_chunk_type: class before the first "-",
        # entity type after the last "-".
        pieces = tok.split('-')
        tok_chunk_class, tok_chunk_type = pieces[0], pieces[-1]

        if chunk_type is None:
            # Start of a chunk.
            chunk_type, chunk_start = tok_chunk_type, i
        elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
            # End of a chunk + start of a new one.
            chunks.append((chunk_type, chunk_start, i - 1))
            chunk_type, chunk_start = tok_chunk_type, i

    # A chunk still open at the end of the sequence ends on the last token.
    # The end index is inclusive, so it is len(seq) - 1, not len(seq).
    if chunk_type is not None:
        chunks.append((chunk_type, chunk_start, len(seq) - 1))

    return chunks
185
186
187
# noinspection PyTypeChecker
188
def get_biobert_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Run the BioBERT NER model over one EHR record.

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx), where the indices are the character
        positions returned by ``test_ehr.get_char_idx``.

    """
    # Two positions fewer than the model sequence length are requested per
    # segment -- presumably room for the [CLS]/[SEP] tokens; TODO confirm.
    boundaries = test_ehr.get_split_points(max_len=BIOBERT_NER_SEQ_LEN - 2)

    # One example per segment, with dummy "O" labels for every word.
    examples = []
    for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
        segment = test_ehr.tokens[seg_start:seg_end]
        examples.append(NerExample(guid=str(seg_start),
                                   words=segment,
                                   labels=["O"] * len(segment)))

    tokenizer = biobert_ner_tokenizer
    input_features = convert_examples_to_features(
        examples,
        biobert_ner_labels,
        max_seq_length=BIOBERT_NER_SEQ_LEN,
        tokenizer=tokenizer,
        cls_token_at_end=False,
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=False,
        pad_on_left=tokenizer.padding_side == "left",
        pad_token=tokenizer.pad_token_id,
        pad_token_segment_id=tokenizer.pad_token_type_id,
        pad_token_label_id=nn.CrossEntropyLoss().ignore_index,
        verbose=0)

    test_dataset = NerTestDataset(input_features)

    raw_scores, label_ids, _ = biobert_ner_trainer.predict(test_dataset)
    per_example = align_predictions(raw_scores, label_ids)

    # Concatenate the per-example label lists into one flat list.
    flat_labels = [label for example in per_example for label in example]

    # Walk the record's tokens. A token that does not start with "##"
    # consumes the next predicted label; a "##" continuation piece repeats
    # "O" or continues the previous entity with an "I-" tag.
    final_predictions = []
    prev_pred = ""
    cursor = 0
    for token in test_ehr.get_tokens():
        if not token.startswith("##"):
            prev_pred = flat_labels[cursor]
            cursor += 1
            final_predictions.append(prev_pred)
        elif prev_pred == "O":
            final_predictions.append(prev_pred)
        else:
            final_predictions.append("I-" + prev_pred.split("-")[-1])

    # Turn token-level chunks into character-level entity spans.
    return [(label,
             test_ehr.get_char_idx(first)[0],
             test_ehr.get_char_idx(last)[1])
            for label, first, last in get_chunks(final_predictions)]
262
263
264
def get_bilstm_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Run the BiLSTM + CRF NER model over one EHR record.

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx), where the indices are the character
        positions returned by ``test_ehr.get_char_idx``.

    """
    boundaries = test_ehr.get_split_points(max_len=BILSTM_NER_SEQ_LEN)

    # Cut the record's tokens into consecutive segments at the split points.
    segments = [test_ehr.tokens[seg_start:seg_end]
                for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:])]

    tagged_segments = bilstm_learn.predict(segments)

    pred_entities = []
    for seg_no, offset in enumerate(boundaries[:-1]):
        for label, first, last in get_chunks(tagged_segments[seg_no]):
            # Chunk indices are relative to the segment; add the segment's
            # start to get a token index into the whole record.
            start_char = test_ehr.get_char_idx(offset + first)[0]
            end_char = test_ehr.get_char_idx(offset + last)[1]
            pred_entities.append((label, start_char, end_char))

    return pred_entities
298
299
300
# noinspection PyTypeChecker
301
def get_ner_predictions(ehr_record: str, model_name: str = "biobert", record_id: str = "1") -> HealthRecord:
    """
    Get predictions for NER using either BioBERT or BiLSTM

    Parameters
    --------------
    ehr_record : str
        An EHR record in text format.

    model_name : str
        The model to use for prediction, matched case-insensitively.
        Default is biobert.

    record_id : str
        The record id of the returned object. Default is 1.

    Returns
    -----------
    A HealthRecord object with entities set.

    Raises
    -----------
    AttributeError
        If model_name is neither 'biobert' nor 'bilstm'.
    """
    model = model_name.lower()

    if model == "biobert":
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=biobert_ner_tokenizer.tokenize,
                                is_bert_tokenizer=True,
                                is_training=False)

        predictions = get_biobert_ner_predictions(test_ehr)

    elif model == "bilstm":
        # record_id is forwarded here as well, so the caller-supplied id is
        # honoured for both models.
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=scispacy_plus_tokenizer,
                                is_bert_tokenizer=False,
                                is_training=False)

        predictions = get_bilstm_ner_predictions(test_ehr)

    else:
        raise AttributeError("Accepted model names include 'biobert' "
                             "and 'bilstm'.")

    ent_preds = []
    for i, pred in enumerate(predictions):
        ent = Entity("T%d" % i, label_ent_map[pred[0]], [pred[1], pred[2]])
        ent_text = test_ehr.text[ent[0]:ent[1]]

        # Drop spans that contain no letter or digit at all.
        if not any(letter.isalnum() for letter in ent_text):
            continue

        ent.set_text(ent_text)
        ent_preds.append(ent)

    test_ehr.entities = ent_preds
    return test_ehr
353
354
355
def get_re_predictions(test_ehr: HealthRecord) -> HealthRecord:
    """
    Predict relations for a record and attach them to it.

    Parameters
    -----------
    test_ehr : HealthRecord
        A HealthRecord object with entities set.

    Returns
    --------
    HealthRecord
        The original object with relations set. The kept relations are
        numbered "R1", "R2", ... in the order they appear in the dataset.
    """
    test_dataset = RETestDataset(test_ehr, biobert_ner_tokenizer,
                                 BIOBERT_RE_SEQ_LEN, re_label_list)

    # Nothing to classify: return early with an empty relation list.
    if len(test_dataset) == 0:
        test_ehr.relations = []
        return test_ehr

    scores = biobert_re_trainer.predict(test_dataset=test_dataset).predictions
    predicted_labels = np.argmax(scores, axis=1)

    # Keep only the candidate relations predicted as class 1.
    rel_preds = [relation
                 for relation, label in zip(test_dataset.relation_list,
                                            predicted_labels)
                 if label == 1]

    # Give the kept relations consecutive ids starting at R1.
    for number, relation in enumerate(rel_preds, start=1):
        relation.ann_id = "R%d" % number

    test_ehr.relations = rel_preds
    return test_ehr