Diff of /biobert_re/utils_re.py [000000] .. [1de6ed]

--- a
+++ b/biobert_re/utils_re.py
@@ -0,0 +1,439 @@
+import os
+import time
+import random
+from enum import Enum
+from dataclasses import dataclass, field
+from typing import List, Optional, Union, Dict, Tuple
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+from filelock import FileLock
+
+import logging
+
+from transformers import (InputFeatures,
+                          InputExample,
+                          PreTrainedTokenizerBase)
+
+import pandas as pd
+from sklearn.metrics import precision_recall_fscore_support
+
+
+import sys
+sys.path.append("../")
+sys.path.append('./biobert_re/')
+
+from data_processor import glue_convert_examples_to_features, glue_output_modes, glue_processors
+
+import utils
+from ehr import HealthRecord
+from annotations import Relation
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GlueDataTrainingArguments:
+    """
+    Arguments pertaining to the data we are going to input to our model for training and evaluation.
+
+    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
+    line.
+    """
+
+    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: str = field(
+        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+
+    def __post_init__(self):
+        self.task_name = self.task_name.lower()
+
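+# A minimal command-line parsing sketch for the dataclass above, following the
+# HfArgumentParser pattern mentioned in its docstring (script name, task name and
+# paths below are placeholders, not part of this module):
+#
+#   from transformers import HfArgumentParser
+#   parser = HfArgumentParser(GlueDataTrainingArguments)
+#   (data_args,) = parser.parse_args_into_dataclasses()
+#   # e.g. python run_re.py --task_name ehr-re --data_dir dataset/ --max_seq_length 128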
+
+class Split(Enum):
+    train = "train"
+    dev = "dev"
+    test = "test"
+
+
+# noinspection PyTypeChecker
+class REDataset(Dataset):
+    """
+    A class representing a training dataset for Relation Extraction.
+    """
+
+    args: GlueDataTrainingArguments
+    output_mode: str
+    features: List[InputFeatures]
+
+    def __init__(
+        self,
+        args: GlueDataTrainingArguments,
+        tokenizer: PreTrainedTokenizerBase,
+        limit_length: Optional[int] = None,
+        mode: Union[str, Split] = Split.train,
+        cache_dir: Optional[str] = None,
+    ):
+        self.args = args
+        self.processor = glue_processors[args.task_name]()
+        self.output_mode = glue_output_modes[args.task_name]
+        if isinstance(mode, str):
+            try:
+                mode = Split[mode]
+            except KeyError:
+                raise KeyError("mode is not a valid split name")
+
+        # Load data features from cache or dataset file
+        cached_features_file = os.path.join(
+            cache_dir if cache_dir is not None else args.data_dir,
+            "cached_{}_{}_{}_{}".format(
+                mode.value,
+                tokenizer.__class__.__name__,
+                str(args.max_seq_length),
+                args.task_name,
+            ),
+        )
+
+        label_list = self.processor.get_labels()
+
+        self.label_list = label_list
+
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
+
+            if os.path.exists(cached_features_file) and not args.overwrite_cache:
+                start = time.time()
+                self.features = torch.load(cached_features_file)
+                logger.info("Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start)
+            else:
+                logger.info(f"Creating features from dataset file at {args.data_dir}")
+
+                if mode == Split.dev:
+                    examples = self.processor.get_dev_examples(args.data_dir)
+                elif mode == Split.test:
+                    examples = self.processor.get_test_examples(args.data_dir)
+                else:
+                    examples = self.processor.get_train_examples(args.data_dir)
+                if limit_length is not None:
+                    examples = examples[:limit_length]
+                self.features = glue_convert_examples_to_features(
+                    examples,
+                    tokenizer,
+                    max_length=args.max_seq_length,
+                    label_list=label_list,
+                    output_mode=self.output_mode,
+                )
+                start = time.time()
+                torch.save(self.features, cached_features_file)
+
+                logger.info("Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start)
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, i) -> InputFeatures:
+        return self.features[i]
+
+    def get_labels(self):
+        return self.label_list
+
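+# Illustrative construction of training/dev datasets; the tokenizer checkpoint and
+# task name are placeholders and must match an entry in glue_processors:
+#
+#   from transformers import AutoTokenizer
+#   tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
+#   data_args = GlueDataTrainingArguments(task_name="ehr-re", data_dir="dataset/")
+#   train_dataset = REDataset(data_args, tokenizer, mode=Split.train)
+#   dev_dataset = REDataset(data_args, tokenizer, mode="dev")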
+
+class RETestDataset(Dataset):
+    """
+    A class representing a test dataset for Relation Extraction.
+    """
+
+    def __init__(self, test_ehr, tokenizer, max_seq_len, label_list):
+
+        self.re_text_list, self.relation_list = generate_re_test_file(test_ehr)
+
+        if not self.re_text_list:
+            self.features = []
+        else:
+            examples = []
+            for (i, text) in enumerate(self.re_text_list):
+                guid = "%s" % i
+                examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=None))
+
+            self.features = glue_convert_examples_to_features(examples, tokenizer,
+                                                              max_length=max_seq_len,
+                                                              label_list=label_list)
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, i) -> InputFeatures:
+        return self.features[i]
+
+
+def replace_ent_label(text, ent_type, start_idx, end_idx):
+    label = '@'+ent_type+'$'
+    return text[:start_idx]+label+text[end_idx:]
+
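+# For example, tagging the drug mention "aspirin" (characters 13-20):
+#   replace_ent_label("Patient took aspirin daily", "Drug", 13, 20)
+#   -> 'Patient took @Drug$ daily'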
+
+def write_file(file, index, sentence, label, sep, is_test, is_label):
+    if is_test and is_label:  # test_original - test with labels
+        file.write('{}{}{}{}{}'.format(index, sep, sentence, sep, label))
+    elif is_test and not is_label:  # test - test with no labels
+        file.write('{}{}{}'.format(index, sep, sentence))
+    else:  # train
+        file.write('{}{}{}'.format(sentence, sep, label))
+    file.write('\n')
+
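+# For example, a labelled test row is written as "<index><sep><sentence><sep><label>",
+# while a training row drops the index:
+#   write_file(f, 7, "The patient was given @Drug$ for @ADE$", 1, '\t', is_test=True, is_label=True)
+#   # writes: 7\tThe patient was given @Drug$ for @ADE$\t1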
+
+def get_char_split_points(record, max_len):
+    char_split_points = []
+
+    split_points = record.get_split_points(max_len=max_len)
+    for pt in split_points[:-1]:
+        char_split_points.append(record.get_char_idx(pt)[1])
+
+    if len(char_split_points) == 1:
+        return char_split_points
+    else:
+        return char_split_points[1:]
+
+
+def replace_entity_text(split_text, ent1, ent2, split_offset):
+    # Remove split offset
+    ent1_start = ent1.range[0] - split_offset
+    ent1_end = ent1.range[1] - split_offset
+
+    ent2_start = ent2.range[0] - split_offset
+    ent2_end = ent2.range[1] - split_offset
+
+    # If entity 1 is present before entity 2
+    if ent1_end < ent2_end:
+        # Insert entity 2 and then entity 1
+        modified_text = replace_ent_label(split_text, ent2.name, ent2_start, ent2_end)
+        modified_text = replace_ent_label(modified_text, ent1.name, ent1_start, ent1_end)
+
+    # If entity 1 is present after entity 2
+    else:
+        # Insert entity 1 and then entity 2
+        modified_text = replace_ent_label(split_text, ent1.name, ent1_start, ent1_end)
+        modified_text = replace_ent_label(modified_text, ent2.name, ent2_start, ent2_end)
+
+    return modified_text
+
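+# E.g. for a Drug entity covering "aspirin" and an ADE entity covering "nausea" in the
+# split text "aspirin caused nausea", the result is "@Drug$ caused @ADE$".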
+
+def generate_re_input_files(ehr_records: List[HealthRecord], filename: str,
+                            ade_records: List[Dict] = None, max_len: int = 128,
+                            is_test=False, is_label=True, is_predict=False, sep: str = '\t'):
+
+    random.seed(0)
+
+    index = 0
+    index_rel_label_map = []
+
+    with open(filename, 'w') as file:
+        # Write headers
+        write_file(file, 'index', 'sentence', 'label', sep, is_test, is_label)
+
+        # Preprocess EHR records
+        for record in ehr_records:
+            text = record.text
+            entities = record.get_entities()
+
+            if is_predict:
+                true_relations = None
+            else:
+                true_relations = record.get_relations()
+
+            # get character split points
+            char_split_points = get_char_split_points(record, max_len)
+
+            start = 0
+            end = char_split_points[0]
+
+            for i in range(len(char_split_points)):
+                # Obtain only entities within the split text
+                range_entities = {ent_id: ent for ent_id, ent in
+                                  filter(lambda item: int(item[1][0]) >= start and int(item[1][1]) <= end,
+                                         entities.items())}
+
+                # Get all possible relations within the split text
+                possible_relations = utils.map_entities(range_entities, true_relations)
+
+                for rel, label in possible_relations:
+                    # Randomly keep only ~25% of negative (label 0) examples,
+                    # except ADE-Drug candidates, which are always kept
+                    if label == 0 and rel.name != "ADE-Drug":
+                        if random.random() > 0.25:
+                            continue
+
+                    split_text = text[start:end]
+                    split_offset = start
+
+                    ent1 = rel.get_entities()[0]
+                    ent2 = rel.get_entities()[1]
+
+                    # Check if both entities are within split text
+                    if ent1.range[0] >= start and ent1.range[1] < end and \
+                            ent2.range[0] >= start and ent2.range[1] < end:
+
+                        modified_text = replace_entity_text(split_text, ent1, ent2, split_offset)
+
+                        # Replace newline and tab characters with spaces
+                        final_text = modified_text.replace('\n', ' ').replace('\t', ' ')
+                        write_file(file, index, final_text, label, sep, is_test, is_label)
+
+                        if is_predict:
+                            index_rel_label_map.append({'relation': rel})
+                        else:
+                            index_rel_label_map.append({'label': label, 'relation': rel})
+
+                        index += 1
+
+                start = end
+                if i != len(char_split_points)-1:
+                    end = char_split_points[i+1]
+                else:
+                    end = len(text)+1
+
+        # Preprocess ADE records
+        if ade_records is not None:
+            for record in ade_records:
+                entities = record['entities']
+                true_relations = record['relations']
+                possible_relations = utils.map_entities(entities, true_relations)
+
+                for rel, label in possible_relations:
+
+                    # Randomly drop ~50% of positive (label 1) ADE-corpus examples
+                    if label == 1 and random.random() > 0.5:
+                        continue
+
+                    new_tokens = record['tokens'].copy()
+
+                    for ent in rel.get_entities():
+                        ent_type = ent.name
+
+                        start_tok = ent.range[0]
+                        end_tok = ent.range[1]+1
+
+                        for i in range(start_tok, end_tok):
+                            new_tokens[i] = '@'+ent_type+'$'
+
+                    """Remove consecutive repeating entities.
+                    Eg. this is @ADE$ @ADE$ @ADE$ for @Drug$ @Drug$ -> this is @ADE$ for @Drug$"""
+                    final_tokens = [new_tokens[i] for i in range(len(new_tokens))\
+                                    if (i == 0) or new_tokens[i] != new_tokens[i-1]]
+
+                    final_text = " ".join(final_tokens)
+
+                    write_file(file, index, final_text, label, sep, is_test, is_label)
+                    index_rel_label_map.append({'label': label, 'relation': rel})
+                    index += 1
+
+    # Use os.path.splitext so paths containing extra dots do not break the unpacking
+    filename, _ = os.path.splitext(filename)
+    utils.save_pickle(filename + '_rel.pkl', index_rel_label_map)
+
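+# Illustrative call for building a training file (records and paths are placeholders);
+# a companion <name>_rel.pkl file, listing the relation (and label) for each written row,
+# is saved alongside the .tsv:
+#
+#   generate_re_input_files(train_records, 'dataset/train.tsv', ade_records=ade_data,
+#                           max_len=128, is_test=False, is_label=True)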
+
+def get_eval_results(answer_path, output_path):
+    """
+    Get evaluation metrics for predictions
+
+    Parameters
+    ------------
+    answer_path : str
+        Path to the labelled test file (test.tsv). Tab-separated, one example
+        per line, with the true label in the third column.
+
+    output_path : str
+        Path to test_predictions.txt containing the model-generated predictions.
+    """
+    testdf = pd.read_csv(answer_path, sep="\t", index_col=0)
+    preddf = pd.read_csv(output_path, sep="\t", header=None)
+
+    pred = [preddf.iloc[i].tolist() for i in preddf.index]
+    # Skip the first row (header line) and read the predicted class from the second column
+    pred_class = [int(v[1]) for v in pred[1:]]
+
+    p, r, f, s = precision_recall_fscore_support(y_pred=pred_class, y_true=testdf["label"])
+
+    # Metrics are reported for the positive class (index 1);
+    # specificity is the recall of the negative class (index 0)
+    results = dict()
+    results["f1 score"] = f[1]
+    results["recall"] = r[1]
+    results["precision"] = p[1]
+    results["specificity"] = r[0]
+    return results
+
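+# Illustrative call (paths are placeholders; see the docstring for the expected file formats):
+#   metrics = get_eval_results("output/test_original.tsv", "output/test_predictions.txt")
+#   # -> {'f1 score': ..., 'recall': ..., 'precision': ..., 'specificity': ...}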
+
+def generate_re_test_file(ehr_record: HealthRecord,
+                          max_len: int = 128) -> Tuple[List[str], List[Relation]]:
+    """
+    Generates test file for Relation Extraction.
+
+    Parameters
+    -----------
+    ehr_record : HealthRecord
+        The EHR record with entities set.
+
+    max_len : int
+        The maximum length of sequence.
+
+    Returns
+    --------
+    Tuple[List[str], List[Relation]]
+        A list of sequences with each entity replaced by its tag,
+        and a list of Relation objects representing the relations in those sequences.
+    """
+    random.seed(0)
+
+    re_text_list = []
+    relation_list = []
+
+    text = ehr_record.text
+    entities = ehr_record.get_entities()
+    if isinstance(entities, dict):
+        entities = list(entities.values())
+
+    # get character split points
+    char_split_points = get_char_split_points(ehr_record, max_len)
+
+    start = 0
+    end = char_split_points[0]
+
+    for i in range(len(char_split_points)):
+        # Obtain only entities within the split text
+        range_entities = [ent for ent in filter(lambda item: int(item[0]) >= start and int(item[1]) <= end,
+                                                entities)]
+
+        # Get all possible relations within the split text
+        possible_relations = utils.map_entities(range_entities)
+
+        for rel, label in possible_relations:
+            split_text = text[start:end]
+            split_offset = start
+
+            ent1 = rel.get_entities()[0]
+            ent2 = rel.get_entities()[1]
+
+            # Check if both entities are within split text
+            if ent1[0] >= start and ent1[1] < end and \
+                    ent2[0] >= start and ent2[1] < end:
+
+                modified_text = replace_entity_text(split_text, ent1, ent2, split_offset)
+
+                # Replace newline and tab characters with spaces
+                final_text = modified_text.replace('\n', ' ').replace('\t', ' ')
+
+                re_text_list.append(final_text)
+                relation_list.append(rel)
+
+        start = end
+        if i != len(char_split_points)-1:
+            end = char_split_points[i+1]
+        else:
+            end = len(text)+1
+
+    assert len(re_text_list) == len(relation_list)
+
+    return re_text_list, relation_list
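+
+# Illustrative prediction-time flow for a single HealthRecord with entities already set
+# (the tokenizer, record and train_dataset objects are assumed to be created elsewhere):
+#   texts, relations = generate_re_test_file(record, max_len=128)
+#   test_dataset = RETestDataset(record, tokenizer, max_seq_len=128,
+#                                label_list=train_dataset.get_labels())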