--- a +++ b/biobert_ner/utils_ner.py @@ -0,0 +1,373 @@ +import sys +sys.path.append("../") + +import logging +import os +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union, Dict +from ehr import HealthRecord + +from filelock import FileLock + +from transformers import PreTrainedTokenizer +import torch +from torch import nn +from torch.utils.data.dataset import Dataset + + +logger = logging.getLogger(__name__) + + +@dataclass +class InputExample: + """ + A single training/test example for token classification. + + Args: + guid: Unique id for the example. + words: list. The words of the sequence. + labels: (Optional) list. The labels for each word of the sequence. This should be + specified for train and dev examples, but not for test examples. + """ + + guid: str + words: List[str] + labels: Optional[List[str]] + + +@dataclass +class InputFeatures: + """ + A single set of features of data. + Property names are the same names as the corresponding inputs to a model. + """ + + input_ids: List[int] + attention_mask: List[int] + token_type_ids: Optional[List[int]] = None + label_ids: Optional[List[int]] = None + + +class Split(Enum): + train = "train_dev" + dev = "devel" + test = "test" + +class NerTestDataset(Dataset): + """ + Dataset for test examples + """ + features: List[InputFeatures] + pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index + + def __init__(self, input_features): + self.features = input_features + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> InputFeatures: + return self.features[i] + + +class NerDataset(Dataset): + + features: List[InputFeatures] + pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index + + # Use cross entropy ignore_index as padding label id so that only + # real label ids contribute to the loss later. + + def __init__( + self, + data_dir: str, + tokenizer: PreTrainedTokenizer, + labels: List[str], + model_type: str, + max_seq_length: Optional[int] = None, + overwrite_cache=False, + mode: Split = Split.train, + ): + # Load data features from cache or dataset file + cached_features_file = os.path.join( + data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)), + ) + + # Make sure only the first process in distributed training processes the dataset, + # and the others will use the cache. + lock_path = cached_features_file + ".lock" + with FileLock(lock_path): + + if os.path.exists(cached_features_file) and not overwrite_cache: + logger.info(f"Loading features from cached file {cached_features_file}") + self.features = torch.load(cached_features_file) + else: + logger.info(f"Creating features from dataset file at {data_dir}") + examples = read_examples_from_file(data_dir, mode) + self.features = convert_examples_to_features( + examples, + labels, + max_seq_length, + tokenizer, + cls_token_at_end=bool(model_type in ["xlnet"]), + # xlnet has a cls token at the end + cls_token=tokenizer.cls_token, + cls_token_segment_id=2 if model_type in ["xlnet"] else 0, + sep_token=tokenizer.sep_token, + sep_token_extra=False, + # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 + pad_on_left=bool(tokenizer.padding_side == "left"), + pad_token=tokenizer.pad_token_id, + pad_token_segment_id=tokenizer.pad_token_type_id, + pad_token_label_id=self.pad_token_label_id, + ) + logger.info(f"Saving features into cached file {cached_features_file}") + torch.save(self.features, cached_features_file) + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> InputFeatures: + return self.features[i] + + +def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]: + if isinstance(mode, Split): + mode = mode.value + file_path = os.path.join(data_dir, f"{mode}.txt") + guid_index = 1 + examples = [] + with open(file_path, encoding="utf-8") as f: + words = [] + labels = [] + for line in f: + line = line.rstrip() + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if words: + examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels)) + guid_index += 1 + words = [] + labels = [] + else: + splits = line.split(" ") + words.append(splits[0]) + if len(splits) > 1: + labels.append(splits[-1].replace("\n", "")) + else: + # Examples could have no label for mode = "test" + labels.append("O") + if words: + examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels)) + return examples + + +def convert_examples_to_features( + examples: List[InputExample], + label_list: List[str], + max_seq_length: int, + tokenizer: PreTrainedTokenizer, + cls_token_at_end=False, + cls_token="[CLS]", + cls_token_segment_id=1, + sep_token="[SEP]", + sep_token_extra=False, + pad_on_left=False, + pad_token=0, + pad_token_segment_id=0, + pad_token_label_id=-100, + sequence_a_segment_id=0, + mask_padding_with_zero=True, + verbose=1, +) -> List[InputFeatures]: + """ + Loads a data file into a list of `InputFeatures` + `cls_token_at_end` define the location of the CLS token: + - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] + - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] + `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) + """ + + label_map = {label: i for i, label in enumerate(label_list)} + + features = [] + for (ex_index, example) in enumerate(examples): + if ex_index % 10_000 == 0: + logger.info("Writing example %d of %d", ex_index, len(examples)) + + tokens = [] + label_ids = [] + for word, label in zip(example.words, example.labels): + tokens.append(word) + if word.startswith("##"): + label_ids.append(pad_token_label_id) + else: + label_ids.append(label_map[label]) + + # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. + special_tokens_count = tokenizer.num_special_tokens_to_add() + if len(tokens) > max_seq_length - special_tokens_count: + logger.info("Length %d exceeds max seq len, truncating." % len(tokens)) + tokens = tokens[: (max_seq_length - special_tokens_count)] + label_ids = label_ids[: (max_seq_length - special_tokens_count)] + + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + + tokens += [sep_token] + label_ids += [pad_token_label_id] + if sep_token_extra: + # roberta uses an extra separator b/w pairs of sentences + tokens += [sep_token] + label_ids += [pad_token_label_id] + segment_ids = [sequence_a_segment_id] * len(tokens) + + if cls_token_at_end: + tokens += [cls_token] + label_ids += [pad_token_label_id] + segment_ids += [cls_token_segment_id] + else: + tokens = [cls_token] + tokens + label_ids = [pad_token_label_id] + label_ids + segment_ids = [cls_token_segment_id] + segment_ids + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) + + # Zero-pad up to the sequence length. + padding_length = max_seq_length - len(input_ids) + if pad_on_left: + input_ids = ([pad_token] * padding_length) + input_ids + input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask + segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids + label_ids = ([pad_token_label_id] * padding_length) + label_ids + else: + input_ids += [pad_token] * padding_length + input_mask += [0 if mask_padding_with_zero else 1] * padding_length + segment_ids += [pad_token_segment_id] * padding_length + label_ids += [pad_token_label_id] * padding_length + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + assert len(label_ids) == max_seq_length + + if ex_index < 2 and verbose == 1: + logger.info("*** Example ***") + logger.info("guid: %s", example.guid) + logger.info("tokens: %s", " ".join([str(x) for x in tokens])) + logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) + logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) + logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) + logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) + + if "token_type_ids" not in tokenizer.model_input_names: + segment_ids = None + + features.append( + InputFeatures( + input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids, label_ids=label_ids + ) + ) + return features + + +def get_labels(path: str) -> List[str]: + if path: + with open(path, "r") as f: + labels = f.read().splitlines() + if "O" not in labels: + labels = ["O"] + labels + return labels + else: + return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"] + + +def generate_input_files(ehr_records: List[HealthRecord], filename: str, + ade_records: List[Dict] = None, max_len: int = 510, + sep: str = ' '): + """ + Write EHR and ADE records to a file. + + Parameters + ---------- + ehr_records : List[HealthRecord] + List of EHR records. + + ade_records : List[Dict] + List of ADE records. + + filename : str + File name to write to. + + max_len : int, optional + Max length of an example. The default is 510. + + sep : str, optional + Token-label separator. The default is a space. + + """ + with open(filename, 'w') as f: + for record in ehr_records: + + split_idx = record.get_split_points(max_len=max_len) + labels = record.get_labels() + tokens = record.get_tokens() + + start = split_idx[0] + end = split_idx[1] + + for i in range(1, len(split_idx)): + for (token, label) in zip(tokens[start:end + 1], labels[start:end + 1]): + f.write('{}{}{}\n'.format(token, sep, label)) + + start = end + 1 + if i != len(split_idx) - 1: + end = split_idx[i + 1] + f.write('\n') + f.write('\n') + + if ade_records is not None: + + for ade in ade_records: + ade_tokens = ade['tokens'] + ade_entities = ade['entities'] + + ent_label_map = {'Drug': 'DRUG', 'Adverse-Effect': 'ADE', 'ADE': 'ADE'} + ade_labels = ['O'] * len(ade_tokens) + + for ent in ade_entities.values(): + ent_type = ent.name + start_idx = ent.range[0] + end_idx = ent.range[1] + + for idx in range(start_idx, end_idx + 1): + if idx == start_idx: + ade_labels[idx] = 'B-' + ent_label_map[ent_type] + else: + ade_labels[idx] = 'I-' + ent_label_map[ent_type] + + for (token, label) in zip(ade_tokens, ade_labels): + f.write('{}{}{}\n'.format(token, sep, label)) + f.write('\n') + + print("Data successfully saved in " + filename)