--- a
+++ b/biobert_ner/utils_ner.py
@@ -0,0 +1,373 @@
+import sys
+sys.path.append("../")
+
+import logging
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional, Union, Dict
+from ehr import HealthRecord
+
+from filelock import FileLock
+
+from transformers import PreTrainedTokenizer
+import torch
+from torch import nn
+from torch.utils.data.dataset import Dataset
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InputExample:
+    """
+    A single training/test example for token classification.
+
+    Args:
+        guid: Unique id for the example.
+        words: list. The words of the sequence.
+        labels: (Optional) list. The labels for each word of the sequence. This should be
+            specified for train and dev examples, but not for test examples.
+    """
+
+    guid: str
+    words: List[str]
+    labels: Optional[List[str]]
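+
+# Illustrative example (values are made up): words are expected to be WordPiece
+# sub-tokens and labels IOB2 tags such as B-DRUG / B-ADE / O.
+#
+#   InputExample(
+#       guid="train_dev-1",
+#       words=["Aspirin", "81", "mg", "daily"],
+#       labels=["B-DRUG", "O", "O", "O"],
+#   )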
+
+
+@dataclass
+class InputFeatures:
+    """
+    A single set of features of data.
+    Property names are the same names as the corresponding inputs to a model.
+    """
+
+    input_ids: List[int]
+    attention_mask: List[int]
+    token_type_ids: Optional[List[int]] = None
+    label_ids: Optional[List[int]] = None
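+
+# Illustrative sketch of a padded feature (all values are made up;
+# max_seq_length=8, pad_token=0, -100 is the CrossEntropyLoss ignore index):
+#
+#   InputFeatures(
+#       input_ids=[101, 21604, 6021, 102, 0, 0, 0, 0],
+#       attention_mask=[1, 1, 1, 1, 0, 0, 0, 0],
+#       token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0],
+#       label_ids=[-100, 1, -100, -100, -100, -100, -100, -100],
+#   )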
+
+
+class Split(Enum):
+    train = "train_dev"
+    dev = "devel"
+    test = "test"
+
+
+class NerTestDataset(Dataset):
+    """
+    Dataset for test examples
+    """
+    features: List[InputFeatures]
+    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
+
+    def __init__(self, input_features):
+        self.features = input_features
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, i) -> InputFeatures:
+        return self.features[i]
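+
+# Usage sketch: wrap features produced by convert_examples_to_features (below)
+# for inference, e.g.
+#
+#   test_dataset = NerTestDataset(input_features=test_features)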
+
+
+class NerDataset(Dataset):
+    """
+    Dataset that reads CoNLL-style token/label files, converts them to model
+    features and caches the result on disk.
+    """
+
+    features: List[InputFeatures]
+
+    # Use the cross entropy ignore_index as the padding label id so that only
+    # real label ids contribute to the loss later.
+    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
+
+    def __init__(
+            self,
+            data_dir: str,
+            tokenizer: PreTrainedTokenizer,
+            labels: List[str],
+            model_type: str,
+            max_seq_length: Optional[int] = None,
+            overwrite_cache=False,
+            mode: Split = Split.train,
+    ):
+        # Load data features from cache or dataset file
+        cached_features_file = os.path.join(
+            data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
+        )
+
+        # Make sure only the first process in distributed training processes the
+        # dataset; the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
+
+            if os.path.exists(cached_features_file) and not overwrite_cache:
+                logger.info(f"Loading features from cached file {cached_features_file}")
+                self.features = torch.load(cached_features_file)
+            else:
+                logger.info(f"Creating features from dataset file at {data_dir}")
+                examples = read_examples_from_file(data_dir, mode)
+                self.features = convert_examples_to_features(
+                    examples,
+                    labels,
+                    max_seq_length,
+                    tokenizer,
+                    # XLNet has a cls token at the end
+                    cls_token_at_end=bool(model_type in ["xlnet"]),
+                    cls_token=tokenizer.cls_token,
+                    cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
+                    sep_token=tokenizer.sep_token,
+                    # sep_token_extra would be True for RoBERTa, which uses an extra separator
+                    # between pairs of sentences (cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805)
+                    sep_token_extra=False,
+                    pad_on_left=bool(tokenizer.padding_side == "left"),
+                    pad_token=tokenizer.pad_token_id,
+                    pad_token_segment_id=tokenizer.pad_token_type_id,
+                    pad_token_label_id=self.pad_token_label_id,
+                )
+                logger.info(f"Saving features into cached file {cached_features_file}")
+                torch.save(self.features, cached_features_file)
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, i) -> InputFeatures:
+        return self.features[i]
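+
+# Minimal usage sketch (checkpoint name, data_dir and labels file are
+# assumptions, not defined in this module):
+#
+#   from transformers import AutoTokenizer
+#
+#   tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
+#   labels = get_labels("dataset/labels.txt")
+#   train_dataset = NerDataset(
+#       data_dir="dataset",
+#       tokenizer=tokenizer,
+#       labels=labels,
+#       model_type="bert",
+#       max_seq_length=128,
+#       mode=Split.train,
+#   )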
+
+
+def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:
+    if isinstance(mode, Split):
+        mode = mode.value
+    file_path = os.path.join(data_dir, f"{mode}.txt")
+    guid_index = 1
+    examples = []
+    with open(file_path, encoding="utf-8") as f:
+        words = []
+        labels = []
+        for line in f:
+            line = line.rstrip()
+            # line has already been rstripped, so a bare newline shows up as ""
+            if line.startswith("-DOCSTART-") or line == "":
+                if words:
+                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
+                    guid_index += 1
+                    words = []
+                    labels = []
+            else:
+                splits = line.split(" ")
+                words.append(splits[0])
+                if len(splits) > 1:
+                    labels.append(splits[-1])
+                else:
+                    # Examples could have no label for mode = "test"
+                    labels.append("O")
+        if words:
+            examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
+    return examples
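+
+# Expected file format: one "token<sep>label" pair per line, with a blank line
+# between examples (labels shown are illustrative), e.g. data_dir/train_dev.txt:
+#
+#   Aspirin B-DRUG
+#   81 O
+#   mg O
+#   <blank line separates examples>
+#   developed O
+#   severe O
+#   nausea B-ADE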
+
+
+def convert_examples_to_features(
+        examples: List[InputExample],
+        label_list: List[str],
+        max_seq_length: int,
+        tokenizer: PreTrainedTokenizer,
+        cls_token_at_end=False,
+        cls_token="[CLS]",
+        cls_token_segment_id=1,
+        sep_token="[SEP]",
+        sep_token_extra=False,
+        pad_on_left=False,
+        pad_token=0,
+        pad_token_segment_id=0,
+        pad_token_label_id=-100,
+        sequence_a_segment_id=0,
+        mask_padding_with_zero=True,
+        verbose=1,
+) -> List[InputFeatures]:
+    """
+    Loads a data file into a list of `InputFeatures`.
+
+    `cls_token_at_end` defines the location of the CLS token:
+        - False (default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
+        - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
+    `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet).
+    """
+
+    label_map = {label: i for i, label in enumerate(label_list)}
+
+    features = []
+    for (ex_index, example) in enumerate(examples):
+        if ex_index % 10_000 == 0:
+            logger.info("Writing example %d of %d", ex_index, len(examples))
+
+        tokens = []
+        label_ids = []
+        # Words are expected to be WordPiece sub-tokens already; continuation
+        # pieces (prefixed with "##") get the ignore label so that only the
+        # first piece of each word contributes to the loss.
+        for word, label in zip(example.words, example.labels):
+            tokens.append(word)
+            if word.startswith("##"):
+                label_ids.append(pad_token_label_id)
+            else:
+                label_ids.append(label_map[label])
+
+        # Account for the special tokens the tokenizer will add ([CLS] and [SEP] for BERT).
+        special_tokens_count = tokenizer.num_special_tokens_to_add()
+        if len(tokens) > max_seq_length - special_tokens_count:
+            logger.info("Length %d exceeds max seq len, truncating.", len(tokens))
+            tokens = tokens[: (max_seq_length - special_tokens_count)]
+            label_ids = label_ids[: (max_seq_length - special_tokens_count)]
+
+        # The convention in BERT is:
+        # (a) For sequence pairs:
+        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
+        # (b) For single sequences:
+        #  tokens:   [CLS] the dog is hairy . [SEP]
+        #  type_ids:   0   0   0   0  0     0   0
+        #
+        # Where "type_ids" are used to indicate whether this is the first
+        # sequence or the second sequence. The embedding vectors for `type=0` and
+        # `type=1` were learned during pre-training and are added to the wordpiece
+        # embedding vector (and position vector). This is not *strictly* necessary
+        # since the [SEP] token unambiguously separates the sequences, but it makes
+        # it easier for the model to learn the concept of sequences.
+        #
+        # For classification tasks, the first vector (corresponding to [CLS]) is
+        # used as the "sentence vector". Note that this only makes sense because
+        # the entire model is fine-tuned.
+
+        tokens += [sep_token]
+        label_ids += [pad_token_label_id]
+        if sep_token_extra:
+            # roberta uses an extra separator b/w pairs of sentences
+            tokens += [sep_token]
+            label_ids += [pad_token_label_id]
+        segment_ids = [sequence_a_segment_id] * len(tokens)
+
+        if cls_token_at_end:
+            tokens += [cls_token]
+            label_ids += [pad_token_label_id]
+            segment_ids += [cls_token_segment_id]
+        else:
+            tokens = [cls_token] + tokens
+            label_ids = [pad_token_label_id] + label_ids
+            segment_ids = [cls_token_segment_id] + segment_ids
+
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+        # The mask has 1 for real tokens and 0 for padding tokens (flipped when
+        # mask_padding_with_zero is False). Only real tokens are attended to.
+        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+        # Zero-pad up to the sequence length.
+        padding_length = max_seq_length - len(input_ids)
+        if pad_on_left:
+            input_ids = ([pad_token] * padding_length) + input_ids
+            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+            label_ids = ([pad_token_label_id] * padding_length) + label_ids
+        else:
+            input_ids += [pad_token] * padding_length
+            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
+            segment_ids += [pad_token_segment_id] * padding_length
+            label_ids += [pad_token_label_id] * padding_length
+
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
+        assert len(label_ids) == max_seq_length
+
+        if ex_index < 2 and verbose == 1:
+            logger.info("*** Example ***")
+            logger.info("guid: %s", example.guid)
+            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
+            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
+
+        if "token_type_ids" not in tokenizer.model_input_names:
+            segment_ids = None
+
+        features.append(
+            InputFeatures(
+                input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids, label_ids=label_ids
+            )
+        )
+    return features
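+
+# Illustrative call (mirrors the arguments used in NerDataset.__init__; assumes
+# a BERT-style `tokenizer` and uses an example label list):
+#
+#   features = convert_examples_to_features(
+#       examples=read_examples_from_file("dataset", Split.dev),
+#       label_list=["O", "B-DRUG", "I-DRUG", "B-ADE", "I-ADE"],
+#       max_seq_length=128,
+#       tokenizer=tokenizer,
+#       cls_token=tokenizer.cls_token,
+#       sep_token=tokenizer.sep_token,
+#       pad_token=tokenizer.pad_token_id,
+#       pad_token_segment_id=tokenizer.pad_token_type_id,
+#       pad_token_label_id=nn.CrossEntropyLoss().ignore_index,
+#   )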
+
+
+def get_labels(path: str) -> List[str]:
+    if path:
+        with open(path, "r") as f:
+            labels = f.read().splitlines()
+        if "O" not in labels:
+            labels = ["O"] + labels
+        return labels
+    else:
+        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
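+
+# Example contents of a labels file consumed by get_labels (one label per line;
+# "O" is prepended automatically if missing):
+#
+#   B-DRUG
+#   I-DRUG
+#   B-ADE
+#   I-ADE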
+
+
+def generate_input_files(ehr_records: List[HealthRecord], filename: str,
+                         ade_records: List[Dict] = None, max_len: int = 510,
+                         sep: str = ' '):
+    """
+    Write EHR and ADE records to a file in CoNLL-style "token<sep>label" format.
+
+    Parameters
+    ----------
+    ehr_records : List[HealthRecord]
+        List of EHR records.
+
+    filename : str
+        File name to write to.
+
+    ade_records : List[Dict], optional
+        List of ADE records. The default is None.
+
+    max_len : int, optional
+        Max length of an example. The default is 510.
+
+    sep : str, optional
+        Token-label separator. The default is a space.
+    """
+    with open(filename, 'w') as f:
+        for record in ehr_records:
+
+            split_idx = record.get_split_points(max_len=max_len)
+            labels = record.get_labels()
+            tokens = record.get_tokens()
+
+            start = split_idx[0]
+            end = split_idx[1]
+
+            for i in range(1, len(split_idx)):
+                for (token, label) in zip(tokens[start:end + 1], labels[start:end + 1]):
+                    f.write('{}{}{}\n'.format(token, sep, label))
+
+                start = end + 1
+                if i != len(split_idx) - 1:
+                    end = split_idx[i + 1]
+                    f.write('\n')
+            f.write('\n')
+
+        if ade_records is not None:
+
+            for ade in ade_records:
+                ade_tokens = ade['tokens']
+                ade_entities = ade['entities']
+
+                ent_label_map = {'Drug': 'DRUG', 'Adverse-Effect': 'ADE', 'ADE': 'ADE'}
+                ade_labels = ['O'] * len(ade_tokens)
+
+                for ent in ade_entities.values():
+                    ent_type = ent.name
+                    start_idx = ent.range[0]
+                    end_idx = ent.range[1]
+
+                    for idx in range(start_idx, end_idx + 1):
+                        if idx == start_idx:
+                            ade_labels[idx] = 'B-' + ent_label_map[ent_type]
+                        else:
+                            ade_labels[idx] = 'I-' + ent_label_map[ent_type]
+
+                for (token, label) in zip(ade_tokens, ade_labels):
+                    f.write('{}{}{}\n'.format(token, sep, label))
+                f.write('\n')
+
+    print("Data successfully saved to " + filename)
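+
+
+# Minimal usage sketch (train_records and ade_train are assumed to be built
+# elsewhere in the project from the annotated EHR data):
+#
+#   generate_input_files(ehr_records=train_records,
+#                        filename="dataset/train_dev.txt",
+#                        ade_records=ade_train,
+#                        max_len=510)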