"""NER and RE prediction pipeline for EHR records.

At import time this module loads three models:
  * a BioBERT token-classification model for NER,
  * a BiLSTM + CRF model (ELMo embeddings) for NER,
  * a BioBERT sequence-classification model for relation extraction,
and exposes `get_ner_predictions` / `get_re_predictions` to run them
over a single `HealthRecord`.
"""

from transformers import (AutoModelForTokenClassification,
                          AutoModelForSequenceClassification,
                          TrainingArguments,
                          AutoTokenizer,
                          AutoConfig,
                          Trainer)

from biobert_ner.utils_ner import (convert_examples_to_features, get_labels, NerTestDataset)
from biobert_ner.utils_ner import InputExample as NerExample

from biobert_re.utils_re import RETestDataset

from bilstm_crf_ner.model.config import Config as BiLSTMConfig
from bilstm_crf_ner.model.ner_model import NERModel as BiLSTMModel
from bilstm_crf_ner.model.ner_learner import NERLearner as BiLSTMLearner
import en_ner_bc5cdr_md

import numpy as np
import os
from torch import nn
from ehr import HealthRecord
from generate_data import scispacy_plus_tokenizer
from annotations import Entity
import logging

from typing import List, Tuple

logger = logging.getLogger(__name__)

# Maximum (sub-)token sequence lengths accepted by each model.
BIOBERT_NER_SEQ_LEN = 128
BILSTM_NER_SEQ_LEN = 512
BIOBERT_RE_SEQ_LEN = 128

# Label id that marks positions to be skipped (padding / special tokens).
# Hoisted once instead of instantiating a CrossEntropyLoss per use.
IGNORE_INDEX = nn.CrossEntropyLoss().ignore_index

# matplotlib's font manager is noisy at DEBUG level; silence it.
logging.getLogger('matplotlib.font_manager').disabled = True

BIOBERT_NER_MODEL_DIR = "biobert_ner/output_full"
BIOBERT_RE_MODEL_DIR = "biobert_re/output_full"

# =====BioBERT Model for NER======
biobert_ner_labels = get_labels('biobert_ner/dataset_full/labels.txt')
biobert_ner_label_map = {i: label for i, label in enumerate(biobert_ner_labels)}
num_labels_ner = len(biobert_ner_labels)

biobert_ner_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "config.json"),
    num_labels=num_labels_ner,
    id2label=biobert_ner_label_map,
    label2id={label: i for i, label in enumerate(biobert_ner_labels)})

biobert_ner_tokenizer = AutoTokenizer.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1")

biobert_ner_model = AutoModelForTokenClassification.from_pretrained(
    os.path.join(BIOBERT_NER_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_ner_config)

biobert_ner_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)

# The Trainer is used purely for its predict() loop; no training happens.
biobert_ner_trainer = Trainer(model=biobert_ner_model, args=biobert_ner_training_args)

# Maps the IOB tag suffixes emitted by the NER models to the entity
# names used by the annotation layer.
label_ent_map = {'DRUG': 'Drug', 'STR': 'Strength',
                 'DUR': 'Duration', 'ROU': 'Route',
                 'FOR': 'Form', 'ADE': 'ADE',
                 'DOS': 'Dosage', 'REA': 'Reason',
                 'FRE': 'Frequency'}

# =====BiLSTM + CRF model for NER=========
bilstm_config = BiLSTMConfig()
bilstm_model = BiLSTMModel(bilstm_config)
bilstm_learn = BiLSTMLearner(bilstm_config, bilstm_model)
bilstm_learn.load("ner_15e_bilstm_crf_elmo")

# Reuse the scispacy tokenizer as the default tokenizer argument of
# scispacy_plus_tokenizer (patched via __defaults__ so existing callers
# that pass no tokenizer pick it up).
scispacy_tok = en_ner_bc5cdr_md.load().tokenizer
scispacy_plus_tokenizer.__defaults__ = (scispacy_tok,)

# =====BioBERT Model for RE======
re_label_list = ["0", "1"]
re_task_name = "ehr-re"

biobert_re_config = AutoConfig.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "config.json"),
    num_labels=len(re_label_list),
    finetuning_task=re_task_name)

biobert_re_model = AutoModelForSequenceClassification.from_pretrained(
    os.path.join(BIOBERT_RE_MODEL_DIR, "pytorch_model.bin"),
    config=biobert_re_config,)

biobert_re_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)

biobert_re_trainer = Trainer(model=biobert_re_model, args=biobert_re_training_args)


def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> List[List[str]]:
    """
    Get the list of labelled predictions from model output

    Parameters
    ----------
    predictions : np.ndarray
        An array of shape (num_examples, seq_len, num_labels).

    label_ids : np.ndarray
        An array of shape (num_examples, seq_length).
        Has -100 at positions which need to be ignored.

    Returns
    -------
    preds_list : List[List[str]]
        Labelled output.

    """
    preds = np.argmax(predictions, axis=2)
    batch_size, seq_len = preds.shape
    preds_list: List[List[str]] = [[] for _ in range(batch_size)]

    for i in range(batch_size):
        for j in range(seq_len):
            # Positions marked with the ignore index are padding or
            # special tokens and carry no label.
            if label_ids[i, j] != IGNORE_INDEX:
                preds_list[i].append(biobert_ner_label_map[preds[i][j]])

    return preds_list


def get_chunk_type(tok: str) -> Tuple[str, str]:
    """
    Args:
        tok: Label in IOB format

    Returns:
        tuple: ("B", "DRUG")

    """
    tag_class = tok.split('-')[0]
    tag_type = tok.split('-')[-1]

    return tag_class, tag_type


def get_chunks(seq: List[str]) -> List[Tuple[str, int, int]]:
    """
    Given a sequence of tags, group entities and their position

    Args:
        seq: ["O", "O", "B-DRUG", "I-DRUG", ...] sequence of labels

    Returns:
        list of (chunk_type, chunk_start, chunk_end); both indices are
        inclusive token positions.

    Example:
        seq = ["B-DRUG", "I-DRUG", "O", "B-STR"]
        result = [("DRUG", 0, 1), ("STR", 3, 3)]

    """
    default = "O"
    chunks = []
    chunk_type, chunk_start = None, None

    for i, tok in enumerate(seq):
        # End of a chunk 1
        if tok == default and chunk_type is not None:
            # Add a chunk.
            chunk = (chunk_type, chunk_start, i - 1)
            chunks.append(chunk)
            chunk_type, chunk_start = None, None

        # End of a chunk + start of a chunk!
        elif tok != default:
            tok_chunk_class, tok_chunk_type = get_chunk_type(tok)
            if chunk_type is None:
                chunk_type, chunk_start = tok_chunk_type, i
            elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
                chunk = (chunk_type, chunk_start, i - 1)
                chunks.append(chunk)
                chunk_type, chunk_start = tok_chunk_type, i
        else:
            continue

    # A chunk still open at the end of the sequence closes on the last
    # token. Use len(seq) - 1 so the end index stays inclusive, matching
    # every other chunk (was len(seq): off-by-one past the last token).
    if chunk_type is not None:
        chunk = (chunk_type, chunk_start, len(seq) - 1)
        chunks.append(chunk)

    return chunks


# noinspection PyTypeChecker
def get_biobert_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Get predictions for a single EHR record using BioBERT

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx).

    """
    # -2 leaves room for the [CLS] and [SEP] special tokens.
    split_points = test_ehr.get_split_points(max_len=BIOBERT_NER_SEQ_LEN - 2)
    examples = []

    for idx in range(len(split_points) - 1):
        words = test_ehr.tokens[split_points[idx]:split_points[idx + 1]]
        examples.append(NerExample(guid=str(split_points[idx]),
                                   words=words,
                                   labels=["O"] * len(words)))

    input_features = convert_examples_to_features(
        examples,
        biobert_ner_labels,
        max_seq_length=BIOBERT_NER_SEQ_LEN,
        tokenizer=biobert_ner_tokenizer,
        cls_token_at_end=False,
        cls_token=biobert_ner_tokenizer.cls_token,
        cls_token_segment_id=0,
        sep_token=biobert_ner_tokenizer.sep_token,
        sep_token_extra=False,
        pad_on_left=bool(biobert_ner_tokenizer.padding_side == "left"),
        pad_token=biobert_ner_tokenizer.pad_token_id,
        pad_token_segment_id=biobert_ner_tokenizer.pad_token_type_id,
        pad_token_label_id=IGNORE_INDEX,
        verbose=0)

    test_dataset = NerTestDataset(input_features)

    predictions, label_ids, _ = biobert_ner_trainer.predict(test_dataset)
    predictions = align_predictions(predictions, label_ids)

    # Flatten the prediction list
    predictions = [p for ex in predictions for p in ex]

    input_tokens = test_ehr.get_tokens()
    prev_pred = ""
    final_predictions = []
    idx = 0

    # WordPiece continuation tokens ("##...") did not get their own
    # prediction; propagate the label of the token they belong to,
    # demoting "B-" to "I-" so a chunk is never split mid-word.
    for token in input_tokens:
        if token.startswith("##"):
            if prev_pred == "O":
                final_predictions.append(prev_pred)
            else:
                pred_typ = prev_pred.split("-")[-1]
                final_predictions.append("I-" + pred_typ)
        else:
            prev_pred = predictions[idx]
            final_predictions.append(prev_pred)
            idx += 1

    pred_entities = []
    chunk_pred = get_chunks(final_predictions)
    for ent in chunk_pred:
        # Convert inclusive token indices to character offsets:
        # start char of the first token, end char of the last token.
        pred_entities.append((ent[0],
                              test_ehr.get_char_idx(ent[1])[0],
                              test_ehr.get_char_idx(ent[2])[1]))

    return pred_entities


def get_bilstm_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
    """
    Get predictions for a single EHR record using BiLSTM

    Parameters
    ----------
    test_ehr : HealthRecord
        The EHR record, this object should have a tokenizer set.

    Returns
    -------
    pred_entities : List[Tuple[str, int, int]]
        List of predicted Entities each with the format
        ("entity", start_idx, end_idx).

    """
    split_points = test_ehr.get_split_points(max_len=BILSTM_NER_SEQ_LEN)
    examples = []

    for idx in range(len(split_points) - 1):
        words = test_ehr.tokens[split_points[idx]:split_points[idx + 1]]
        examples.append(words)

    predictions = bilstm_learn.predict(examples)

    pred_entities = []
    for idx in range(len(split_points) - 1):
        chunk_pred = get_chunks(predictions[idx])
        for ent in chunk_pred:
            # Chunk indices are relative to the split; offset by the
            # split start before mapping to character positions.
            pred_entities.append((ent[0],
                                  test_ehr.get_char_idx(split_points[idx] + ent[1])[0],
                                  test_ehr.get_char_idx(split_points[idx] + ent[2])[1]))

    return pred_entities


# noinspection PyTypeChecker
def get_ner_predictions(ehr_record: str, model_name: str = "biobert", record_id: str = "1") -> HealthRecord:
    """
    Get predictions for NER using either BioBERT or BiLSTM

    Parameters
    --------------
    ehr_record : str
        An EHR record in text format.

    model_name : str
        The model to use for prediction. Default is biobert.

    record_id : str
        The record id of the returned object. Default is 1.

    Returns
    -----------
    A HealthRecord object with entities set.

    Raises
    ------
    AttributeError
        If `model_name` is neither 'biobert' nor 'bilstm'.
    """
    if model_name.lower() == "biobert":
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=biobert_ner_tokenizer.tokenize,
                                is_bert_tokenizer=True,
                                is_training=False)

        predictions = get_biobert_ner_predictions(test_ehr)

    elif model_name.lower() == "bilstm":
        # record_id was previously dropped on this branch; pass it so
        # both models produce identically-identified records.
        test_ehr = HealthRecord(record_id=record_id,
                                text=ehr_record,
                                tokenizer=scispacy_plus_tokenizer,
                                is_bert_tokenizer=False,
                                is_training=False)
        predictions = get_bilstm_ner_predictions(test_ehr)

    else:
        raise AttributeError("Accepted model names include 'biobert' "
                             "and 'bilstm'.")

    ent_preds = []
    for i, pred in enumerate(predictions):
        ent = Entity("T%d" % i, label_ent_map[pred[0]], [pred[1], pred[2]])
        ent_text = test_ehr.text[ent[0]:ent[1]]

        # Drop spans with no alphanumeric content (pure punctuation /
        # whitespace artifacts of tokenization).
        if not any(letter.isalnum() for letter in ent_text):
            continue

        ent.set_text(ent_text)
        ent_preds.append(ent)

    test_ehr.entities = ent_preds
    return test_ehr


def get_re_predictions(test_ehr: HealthRecord) -> HealthRecord:
    """
    Get predictions for Relation Extraction.

    Parameters
    -----------
    test_ehr : HealthRecord
        A HealthRecord object with entities set.

    Returns
    --------
    HealthRecord
        The original object with relations set.
    """
    test_dataset = RETestDataset(test_ehr, biobert_ner_tokenizer,
                                 BIOBERT_RE_SEQ_LEN, re_label_list)

    # No candidate entity pairs -> nothing to classify.
    if len(test_dataset) == 0:
        test_ehr.relations = []
        return test_ehr

    re_predictions = biobert_re_trainer.predict(test_dataset=test_dataset).predictions
    re_predictions = np.argmax(re_predictions, axis=1)

    # Keep only pairs classified as related (label 1) and assign them
    # sequential annotation ids R1, R2, ...
    idx = 1
    rel_preds = []
    for relation, pred in zip(test_dataset.relation_list, re_predictions):
        if pred == 1:
            relation.ann_id = "R%d" % idx
            idx += 1
            rel_preds.append(relation)

    test_ehr.relations = rel_preds
    return test_ehr