Diff of /predict.py [000000] .. [1de6ed]

Switch to side-by-side view

--- a
+++ b/predict.py
@@ -0,0 +1,388 @@
+
+from transformers import (AutoModelForTokenClassification,
+                          AutoModelForSequenceClassification,
+                          TrainingArguments,
+                          AutoTokenizer,
+                          AutoConfig,
+                          Trainer)
+
+from biobert_ner.utils_ner import (convert_examples_to_features, get_labels, NerTestDataset)
+from biobert_ner.utils_ner import InputExample as NerExample
+
+from biobert_re.utils_re import RETestDataset
+
+from bilstm_crf_ner.model.config import Config as BiLSTMConfig
+from bilstm_crf_ner.model.ner_model import NERModel as BiLSTMModel
+from bilstm_crf_ner.model.ner_learner import NERLearner as BiLSTMLearner
+import en_ner_bc5cdr_md
+
+import numpy as np
+import os
+from torch import nn
+from ehr import HealthRecord
+from generate_data import scispacy_plus_tokenizer
+from annotations import Entity
+import logging
+
+from typing import List, Tuple
+
+# Module-level logger, named after this module.
+logger = logging.getLogger(__name__)
+
+# Sequence-length limits per model. They are passed below as `max_len` when
+# splitting a record into chunks and as the maximum sequence length when
+# building the BioBERT NER features and the RE dataset.
+BIOBERT_NER_SEQ_LEN = 128
+BILSTM_NER_SEQ_LEN = 512
+BIOBERT_RE_SEQ_LEN = 128
+# Switch off the matplotlib font-manager logger.
+logging.getLogger('matplotlib.font_manager').disabled = True
+
+# Directories the BioBERT config and weight files are loaded from.
+# NOTE(review): relative paths - presumably resolved against the working
+# directory the process is started from; confirm.
+BIOBERT_NER_MODEL_DIR = "biobert_ner/output_full"
+BIOBERT_RE_MODEL_DIR = "biobert_re/output_full"
+
+# =====BioBERT Model for NER======
+# NOTE: everything in this section runs at import time.
+# The position of a label in `biobert_ner_labels` is its class id;
+# `biobert_ner_label_map` is the id -> label lookup used by align_predictions.
+biobert_ner_labels = get_labels('biobert_ner/dataset_full/labels.txt')
+biobert_ner_label_map = {i: label for i, label in enumerate(biobert_ner_labels)}
+num_labels_ner = len(biobert_ner_labels)
+
+# Config read from the model directory, with the label count and both label
+# mappings supplied explicitly.
+biobert_ner_config = AutoConfig.from_pretrained(
+    os.path.join(BIOBERT_NER_MODEL_DIR, "config.json"),
+    num_labels=num_labels_ner,
+    id2label=biobert_ner_label_map,
+    label2id={label: i for i, label in enumerate(biobert_ner_labels)})
+
+# NOTE(review): the tokenizer is loaded by model name rather than from
+# BIOBERT_NER_MODEL_DIR. The same tokenizer object is also handed to the RE
+# dataset in get_re_predictions.
+biobert_ner_tokenizer = AutoTokenizer.from_pretrained(
+    "dmis-lab/biobert-base-cased-v1.1")
+
+# Token-classification weights loaded from the model directory.
+biobert_ner_model = AutoModelForTokenClassification.from_pretrained(
+    os.path.join(BIOBERT_NER_MODEL_DIR, "pytorch_model.bin"),
+    config=biobert_ner_config)
+
+# The Trainer is only used through `.predict()` in this module.
+biobert_ner_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)
+
+biobert_ner_trainer = Trainer(model=biobert_ner_model, args=biobert_ner_training_args)
+
+# Maps the short type found in a predicted tag (e.g. "DRUG" from "B-DRUG")
+# to the entity name given to Entity objects in get_ner_predictions.
+label_ent_map = {'DRUG': 'Drug', 'STR': 'Strength',
+                 'DUR': 'Duration', 'ROU': 'Route',
+                 'FOR': 'Form', 'ADE': 'ADE',
+                 'DOS': 'Dosage', 'REA': 'Reason',
+                 'FRE': 'Frequency'}
+
+# =====BiLSTM + CRF model for NER=========
+# NOTE: everything in this section runs at import time.
+bilstm_config = BiLSTMConfig()
+bilstm_model = BiLSTMModel(bilstm_config)
+bilstm_learn = BiLSTMLearner(bilstm_config, bilstm_model)
+# Load previously saved state by name.
+# NOTE(review): how the name is resolved to a file is up to the learner - confirm.
+bilstm_learn.load("ner_15e_bilstm_crf_elmo")
+
+# Take the tokenizer of the loaded en_ner_bc5cdr_md pipeline and install it
+# as the only default argument value of scispacy_plus_tokenizer.
+# NOTE(review): assigning `__defaults__` mutates the imported function object,
+# so every other user of it in this process sees the new default.
+scispacy_tok = en_ner_bc5cdr_md.load().tokenizer
+scispacy_plus_tokenizer.__defaults__ = (scispacy_tok,)
+
+# =====BioBERT Model for RE======
+# NOTE: everything in this section runs at import time.
+# Two-class setup; get_re_predictions keeps the candidates whose predicted
+# class index is 1.
+re_label_list = ["0", "1"]
+re_task_name = "ehr-re"
+
+# Config read from the RE model directory, with the label count and task
+# name supplied explicitly.
+biobert_re_config = AutoConfig.from_pretrained(
+    os.path.join(BIOBERT_RE_MODEL_DIR, "config.json"),
+    num_labels=len(re_label_list),
+    finetuning_task=re_task_name)
+
+# Sequence-classification weights loaded from the RE model directory.
+biobert_re_model = AutoModelForSequenceClassification.from_pretrained(
+    os.path.join(BIOBERT_RE_MODEL_DIR, "pytorch_model.bin"),
+    config=biobert_re_config,)
+
+# The Trainer is only used through `.predict()` in this module.
+biobert_re_training_args = TrainingArguments(output_dir="/tmp", do_predict=True)
+
+biobert_re_trainer = Trainer(model=biobert_re_model, args=biobert_re_training_args)
+
+
+def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> List[List[str]]:
+    """
+    Get the list of labelled predictions from model output
+
+    Parameters
+    ----------
+    predictions : np.ndarray
+        An array of shape (num_examples, seq_len, num_labels).
+
+    label_ids : np.ndarray
+        An array of shape (num_examples, seq_length).
+        Has the ignore index of ``nn.CrossEntropyLoss`` (-100) at positions
+        which need to be ignored.
+
+    Returns
+    -------
+    preds_list : List[List[str]]
+        For every example, the label strings (looked up in
+        ``biobert_ner_label_map``) of the positions that are not ignored,
+        in sequence order.
+
+    """
+    # Read the ignore index once; the value does not change between tokens,
+    # so there is no need to build a loss module inside the loops.
+    ignore_index = nn.CrossEntropyLoss().ignore_index
+
+    # Highest-scoring class id at every position.
+    preds = np.argmax(predictions, axis=2)
+
+    return [[biobert_ner_label_map[pred_id]
+             for pred_id, ref_id in zip(pred_row, ref_row)
+             if ref_id != ignore_index]
+            for pred_row, ref_row in zip(preds, label_ids)]
+
+
+def get_chunk_type(tok: str) -> Tuple[str, str]:
+    """
+    Split an IOB label into its tag class and its entity type.
+
+    Args:
+        tok: Label in IOB format, e.g. "B-DRUG"
+
+    Returns:
+        tuple: (tag_class, tag_type), e.g. ("B", "DRUG"). A label that
+        contains no "-" (such as "O") is returned as both elements.
+
+    """
+    parts = tok.split('-')
+
+    return parts[0], parts[-1]
+
+
+def get_chunks(seq: List[str]) -> List[Tuple[str, int, int]]:
+    """
+    Given a sequence of tags, group entities and their position
+
+    A chunk is closed by an "O" label, by a label of a different type, or
+    by a "B-" label (which immediately opens the next chunk). A chunk may
+    start with an "I-" label when no chunk is open.
+
+    Args:
+        seq: ["O", "O", "B-DRUG", "I-DRUG", ...] sequence of labels
+
+    Returns:
+        list of (chunk_type, chunk_start, chunk_end), where chunk_start and
+        chunk_end are positions in seq and chunk_end is inclusive.
+
+    Example:
+        seq = ["B-DRUG", "I-DRUG", "O", "B-STR"]
+        result = [("DRUG", 0, 1), ("STR", 3, 3)]
+
+    """
+    default = "O"
+    chunks = []
+    chunk_type, chunk_start = None, None
+
+    for i, tok in enumerate(seq):
+        if tok == default:
+            # An "O" label closes the chunk that is currently open, if any.
+            if chunk_type is not None:
+                chunks.append((chunk_type, chunk_start, i - 1))
+                chunk_type, chunk_start = None, None
+            continue
+
+        tok_chunk_class, tok_chunk_type = get_chunk_type(tok)
+        if chunk_type is None:
+            # Nothing open yet: this label starts a chunk.
+            chunk_type, chunk_start = tok_chunk_type, i
+        elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
+            # End of a chunk + start of a chunk!
+            chunks.append((chunk_type, chunk_start, i - 1))
+            chunk_type, chunk_start = tok_chunk_type, i
+
+    # A chunk still open at the end of the sequence ends on the last label.
+    # The end index is inclusive, like the ones produced inside the loop.
+    if chunk_type is not None:
+        chunks.append((chunk_type, chunk_start, len(seq) - 1))
+
+    return chunks
+
+
+# noinspection PyTypeChecker
+def get_biobert_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
+    """
+    Run the BioBERT NER model over a single EHR record
+
+    Parameters
+    ----------
+    test_ehr : HealthRecord
+        The EHR record, this object should have a tokenizer set.
+
+    Returns
+    -------
+    pred_entities : List[Tuple[str, int, int]]
+        One tuple per predicted entity, in the format
+        ("entity", start_idx, end_idx). Both indices are taken from
+        ``test_ehr.get_char_idx``.
+
+    """
+    # Two positions are held back from the length limit
+    # (presumably for the [CLS] and [SEP] tokens - confirm).
+    boundaries = test_ehr.get_split_points(max_len=BIOBERT_NER_SEQ_LEN - 2)
+
+    # One example per span between consecutive split points. The labels are
+    # "O" placeholders, one per word.
+    examples = []
+    for pos in range(len(boundaries) - 1):
+        span_start, span_end = boundaries[pos], boundaries[pos + 1]
+        span_words = test_ehr.tokens[span_start:span_end]
+        examples.append(NerExample(guid=str(span_start),
+                                   words=span_words,
+                                   labels=["O"] * len(span_words)))
+
+    tokenizer = biobert_ner_tokenizer
+    features = convert_examples_to_features(
+        examples,
+        biobert_ner_labels,
+        max_seq_length=BIOBERT_NER_SEQ_LEN,
+        tokenizer=tokenizer,
+        cls_token_at_end=False,
+        cls_token=tokenizer.cls_token,
+        cls_token_segment_id=0,
+        sep_token=tokenizer.sep_token,
+        sep_token_extra=False,
+        pad_on_left=bool(tokenizer.padding_side == "left"),
+        pad_token=tokenizer.pad_token_id,
+        pad_token_segment_id=tokenizer.pad_token_type_id,
+        pad_token_label_id=nn.CrossEntropyLoss().ignore_index,
+        verbose=0)
+
+    raw_scores, label_ids, _ = biobert_ner_trainer.predict(NerTestDataset(features))
+    per_example = align_predictions(raw_scores, label_ids)
+
+    # Concatenate the per-example label lists into one list for the record.
+    flat_labels = [label for example_labels in per_example for label in example_labels]
+
+    # Walk over the record's tokens. A token that does not start with "##"
+    # consumes the next predicted label. A "##" continuation piece repeats
+    # the previous label when it is "O", and otherwise gets an "I-" label of
+    # the same entity type.
+    token_labels = []
+    last_label = ""
+    cursor = 0
+    for piece in test_ehr.get_tokens():
+        if not piece.startswith("##"):
+            last_label = flat_labels[cursor]
+            cursor += 1
+            token_labels.append(last_label)
+        elif last_label == "O":
+            token_labels.append(last_label)
+        else:
+            token_labels.append("I-" + last_label.split("-")[-1])
+
+    # Turn token-level chunks into (type, start, end) tuples using the
+    # record's own index lookup for the first and last token of each chunk.
+    return [(chunk_label,
+             test_ehr.get_char_idx(first_tok)[0],
+             test_ehr.get_char_idx(last_tok)[1])
+            for chunk_label, first_tok, last_tok in get_chunks(token_labels)]
+
+
+def get_bilstm_ner_predictions(test_ehr: HealthRecord) -> List[Tuple[str, int, int]]:
+    """
+    Run the BiLSTM NER model over a single EHR record
+
+    Parameters
+    ----------
+    test_ehr : HealthRecord
+        The EHR record, this object should have a tokenizer set.
+
+    Returns
+    -------
+    pred_entities : List[Tuple[str, int, int]]
+        One tuple per predicted entity, in the format
+        ("entity", start_idx, end_idx). Both indices are taken from
+        ``test_ehr.get_char_idx``.
+
+    """
+    split_points = test_ehr.get_split_points(max_len=BILSTM_NER_SEQ_LEN)
+    n_spans = len(split_points) - 1
+
+    # The token span between each pair of consecutive split points is one
+    # input sequence for the model.
+    examples = [test_ehr.tokens[split_points[k]:split_points[k + 1]]
+                for k in range(n_spans)]
+
+    tag_sequences = bilstm_learn.predict(examples)
+
+    pred_entities = []
+    for k in range(n_spans):
+        # Chunk positions are relative to the span, so shift them by the
+        # span's first token index before looking up the record indices.
+        offset = split_points[k]
+        pred_entities.extend(
+            (chunk_label,
+             test_ehr.get_char_idx(offset + first_tok)[0],
+             test_ehr.get_char_idx(offset + last_tok)[1])
+            for chunk_label, first_tok, last_tok in get_chunks(tag_sequences[k]))
+
+    return pred_entities
+
+
+# noinspection PyTypeChecker
+def get_ner_predictions(ehr_record: str, model_name: str = "biobert", record_id: str = "1") -> HealthRecord:
+    """
+    Get predictions for NER using either BioBERT or BiLSTM
+
+    Parameters
+    --------------
+    ehr_record : str
+        An EHR record in text format.
+
+    model_name : str
+        The model to use for prediction, "biobert" or "bilstm"
+        (case-insensitive). Default is biobert.
+
+    record_id : str
+        The record id of the returned object. Default is 1.
+
+    Returns
+    -----------
+    A HealthRecord object with entities set.
+
+    Raises
+    -----------
+    AttributeError
+        If model_name is neither "biobert" nor "bilstm".
+    """
+    model_key = model_name.lower()
+
+    if model_key == "biobert":
+        test_ehr = HealthRecord(record_id=record_id,
+                                text=ehr_record,
+                                tokenizer=biobert_ner_tokenizer.tokenize,
+                                is_bert_tokenizer=True,
+                                is_training=False)
+
+        predictions = get_biobert_ner_predictions(test_ehr)
+
+    elif model_key == "bilstm":
+        # record_id is passed here as well, so the returned record carries
+        # the requested id regardless of which model is used.
+        test_ehr = HealthRecord(record_id=record_id,
+                                text=ehr_record,
+                                tokenizer=scispacy_plus_tokenizer,
+                                is_bert_tokenizer=False,
+                                is_training=False)
+        predictions = get_bilstm_ner_predictions(test_ehr)
+
+    else:
+        raise AttributeError("Accepted model names include 'biobert' "
+                             "and 'bilstm'.")
+
+    ent_preds = []
+    for i, pred in enumerate(predictions):
+        ent = Entity("T%d" % i, label_ent_map[pred[0]], [pred[1], pred[2]])
+        ent_text = test_ehr.text[ent[0]:ent[1]]
+
+        # Drop predictions whose text has no letter or digit at all.
+        # The "T%d" id comes from the position in `predictions`, so ids of
+        # dropped predictions are not reused.
+        if not any(letter.isalnum() for letter in ent_text):
+            continue
+
+        ent.set_text(ent_text)
+        ent_preds.append(ent)
+
+    test_ehr.entities = ent_preds
+    return test_ehr
+
+
+def get_re_predictions(test_ehr: HealthRecord) -> HealthRecord:
+    """
+    Predict relations for a record and attach them to it.
+
+    Parameters
+    -----------
+    test_ehr : HealthRecord
+        A HealthRecord object with entities set.
+
+    Returns
+    --------
+    HealthRecord
+        The original object with relations set. The list is empty when the
+        dataset built from the record has no items.
+    """
+    test_dataset = RETestDataset(test_ehr, biobert_ner_tokenizer,
+                                 BIOBERT_RE_SEQ_LEN, re_label_list)
+
+    # Nothing to classify: skip the model call entirely.
+    if len(test_dataset) == 0:
+        test_ehr.relations = []
+        return test_ehr
+
+    scores = biobert_re_trainer.predict(test_dataset=test_dataset).predictions
+    predicted_classes = np.argmax(scores, axis=1)
+
+    # Keep the candidates predicted as class 1, in dataset order.
+    kept_relations = [candidate
+                      for candidate, pred in zip(test_dataset.relation_list, predicted_classes)
+                      if pred == 1]
+
+    # Number the kept relations R1, R2, ...
+    for number, candidate in enumerate(kept_relations, start=1):
+        candidate.ann_id = "R%d" % number
+
+    test_ehr.relations = kept_relations
+    return test_ehr