import sys
sys.path.append("../")
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union, Dict
from ehr import HealthRecord
from filelock import FileLock
from transformers import PreTrainedTokenizer
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
logger = logging.getLogger(__name__)
@dataclass
class InputExample:
    """
    A single training/test example for token classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        labels: (Optional) list. The labels for each word of the sequence. This should be
            specified for train and dev examples, but not for test examples.
            Defaults to ``None`` so that unlabeled examples can simply omit it.
    """

    guid: str
    words: List[str]
    # Defaulted so the field is optional in practice, matching its annotation and the docstring.
    labels: Optional[List[str]] = None
@dataclass
class InputFeatures:
    """
    Numeric encoding of one example.

    The attribute names match the names of the corresponding model inputs.
    Only ``input_ids`` and ``attention_mask`` are required; the remaining
    fields default to ``None``.
    """

    # Vocabulary ids of the tokens.
    input_ids: List[int]
    # One mask entry per position of ``input_ids``.
    attention_mask: List[int]
    # Segment ids; left as None when they are not supplied.
    token_type_ids: Optional[List[int]] = None
    # Per-position label ids; left as None when they are not supplied.
    label_ids: Optional[List[int]] = None
class Split(Enum):
    """
    Dataset splits.

    Each member's value is the name stem used for that split on disk: the
    examples are read from ``<value>.txt`` and the value is embedded in the
    feature-cache file name.
    """

    # Note that the training split is stored under the "train_dev" stem.
    train = "train_dev"
    dev = "devel"
    test = "test"
class NerTestDataset(Dataset):
    """
    Dataset for test examples.

    Wraps a sequence of already-built ``InputFeatures`` that is held in
    memory; nothing is read from or written to disk.
    """

    features: List[InputFeatures]
    # Padding label id, taken from CrossEntropyLoss's ``ignore_index``.
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index

    def __init__(self, input_features):
        """Keep a reference to ``input_features`` (no copy is made)."""
        self.features = input_features

    def __len__(self):
        """Return how many feature sets the dataset holds."""
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        """Return the feature set stored at index ``i``."""
        return self.features[i]
class NerDataset(Dataset):
    """
    Token-classification dataset backed by an on-disk feature cache.

    On construction the features are either loaded from a cache file inside
    ``data_dir`` or built from the split's text file and then written to that
    cache file.
    """

    features: List[InputFeatures]
    # The padding label id is CrossEntropyLoss's ignore_index, so that only
    # real label ids contribute to the loss later.
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index

    def __init__(
        self,
        data_dir: str,
        tokenizer: PreTrainedTokenizer,
        labels: List[str],
        model_type: str,
        max_seq_length: Optional[int] = None,
        overwrite_cache=False,
        mode: Split = Split.train,
    ):
        """
        Load cached features, or create and cache them.

        Args:
            data_dir: Directory holding the split's data file; the cache file
                and its lock file are created here too.
            tokenizer: Supplies the special tokens, padding side and padding ids.
            labels: Label names handed to the feature conversion.
            model_type: Only ``"xlnet"`` changes the layout (CLS token at the
                end, CLS segment id 2).
            max_seq_length: Forwarded to the feature conversion and embedded in
                the cache file name.
            overwrite_cache: When true, rebuild the features even if a cache
                file already exists.
            mode: Which split to load.
        """
        cache_name = f"cached_{mode.value}_{tokenizer.__class__.__name__}_{max_seq_length}"
        cached_features_file = os.path.join(data_dir, cache_name)
        is_xlnet = model_type in ["xlnet"]

        # Hold a file lock for the whole load-or-build step so that, in
        # distributed training, only the first process builds the features
        # and the others find the finished cache.
        with FileLock(cached_features_file + ".lock"):
            if overwrite_cache or not os.path.exists(cached_features_file):
                logger.info(f"Creating features from dataset file at {data_dir}")
                raw_examples = read_examples_from_file(data_dir, mode)
                self.features = convert_examples_to_features(
                    raw_examples,
                    labels,
                    max_seq_length,
                    tokenizer,
                    # xlnet has a cls token at the end
                    cls_token_at_end=is_xlnet,
                    cls_token=tokenizer.cls_token,
                    cls_token_segment_id=2 if is_xlnet else 0,
                    sep_token=tokenizer.sep_token,
                    # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                    sep_token_extra=False,
                    pad_on_left=tokenizer.padding_side == "left",
                    pad_token=tokenizer.pad_token_id,
                    pad_token_segment_id=tokenizer.pad_token_type_id,
                    pad_token_label_id=self.pad_token_label_id,
                )
                logger.info(f"Saving features into cached file {cached_features_file}")
                torch.save(self.features, cached_features_file)
            else:
                logger.info(f"Loading features from cached file {cached_features_file}")
                # NOTE(review): torch.load unpickles the cache file, so only
                # caches produced by this code should be loaded. Recent PyTorch
                # releases may also require weights_only=False for a list of
                # InputFeatures objects; confirm against the installed version.
                self.features = torch.load(cached_features_file)

    def __len__(self):
        """Return the number of feature sets."""
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        """Return the feature set at index ``i``."""
        return self.features[i]
def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:
    """
    Parse ``<data_dir>/<mode>.txt`` into a list of ``InputExample`` objects.

    The file holds one token per line; fields are separated by single spaces,
    the first field being the word and the last field its label. A line with a
    single field gets the label ``"O"``. An empty line, or a line starting with
    ``-DOCSTART-``, ends the current example. Examples are numbered from 1 and
    their guid is ``"<mode>-<number>"``.

    Args:
        data_dir: Directory containing the data file.
        mode: A ``Split`` member (its value is used) or the file-name stem itself.

    Returns:
        The examples in file order.
    """
    split_name = mode.value if isinstance(mode, Split) else mode
    source_path = os.path.join(data_dir, f"{split_name}.txt")

    collected = []
    next_guid = 1
    current_words, current_labels = [], []

    with open(source_path, encoding="utf-8") as handle:
        for raw_line in handle:
            # Trailing whitespace (including the newline) is removed up front,
            # so a separator line shows up as an empty string.
            text = raw_line.rstrip()

            if text != "" and not text.startswith("-DOCSTART-"):
                fields = text.split(" ")
                current_words.append(fields[0])
                # Examples could have no label for mode = "test"
                current_labels.append(fields[-1] if len(fields) > 1 else "O")
                continue

            # Separator line: emit the pending example, if there is one.
            if current_words:
                collected.append(
                    InputExample(guid=f"{split_name}-{next_guid}", words=current_words, labels=current_labels)
                )
                next_guid += 1
                current_words, current_labels = [], []

    # The file may end without a trailing separator line.
    if current_words:
        collected.append(
            InputExample(guid=f"{split_name}-{next_guid}", words=current_words, labels=current_labels)
        )
    return collected
def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_seq_length: int,
    tokenizer: PreTrainedTokenizer,
    cls_token_at_end=False,
    cls_token="[CLS]",
    cls_token_segment_id=1,
    sep_token="[SEP]",
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    pad_token_label_id=-100,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True,
    verbose=1,
) -> List[InputFeatures]:
    """
    Convert a list of `InputExample` into a list of fixed-length `InputFeatures`.

    Every word of an example is used as one token as-is; no further
    tokenization happens here. A word starting with "##" gets
    `pad_token_label_id` as its label id, any other word gets the index of its
    label in `label_list`. Sequences that do not fit are truncated, the special
    tokens are added, and everything is padded to `max_seq_length`.

    `cls_token_at_end` define the location of the CLS token:
        - False (Default, BERT/XLM pattern): [CLS] + A + [SEP]
        - True (XLNet/GPT pattern): A + [SEP] + [CLS]
    `cls_token_segment_id` define the segment id associated to the CLS token.
    `sep_token_extra` appends a second separator token after the first one.
    `pad_on_left` puts the padding in front of the sequence instead of after it.
    `mask_padding_with_zero` selects the mask convention: 1 for real tokens and
    0 for padding when true, the inverse when false.
    `verbose` equal to 1 logs the first two converted examples.

    The returned features carry `token_type_ids=None` when the tokenizer does
    not list "token_type_ids" among its model input names.
    """
    label_to_id = {name: idx for idx, name in enumerate(label_list)}
    # The mask has one value for real tokens and the opposite value for
    # padding tokens. Only real tokens are attended to.
    attend_flag = 1 if mask_padding_with_zero else 0
    ignore_flag = 0 if mask_padding_with_zero else 1

    converted = []
    for position, sample in enumerate(examples):
        if position % 10_000 == 0:
            logger.info("Writing example %d of %d", position, len(examples))

        pairs = list(zip(sample.words, sample.labels))
        pieces = [word for word, _ in pairs]
        targets = [
            pad_token_label_id if word.startswith("##") else label_to_id[tag]
            for word, tag in pairs
        ]

        # Leave room for the special tokens the tokenizer reports it adds.
        budget = max_seq_length - tokenizer.num_special_tokens_to_add()
        if len(pieces) > budget:
            logger.info("Length %d exceeds max seq len, truncating." % len(pieces))
            pieces = pieces[:budget]
            targets = targets[:budget]

        # Single-sequence layout (BERT convention):
        #   tokens:   [CLS] the dog is hairy . [SEP]
        #   type_ids: 0     0   0   0  0     0 0
        # Special tokens always carry the padding label id.
        closing = [sep_token] * (2 if sep_token_extra else 1)
        pieces = pieces + closing
        targets = targets + [pad_token_label_id] * len(closing)
        segments = [sequence_a_segment_id] * len(pieces)

        if cls_token_at_end:
            pieces.append(cls_token)
            targets.append(pad_token_label_id)
            segments.append(cls_token_segment_id)
        else:
            pieces.insert(0, cls_token)
            targets.insert(0, pad_token_label_id)
            segments.insert(0, cls_token_segment_id)

        ids = tokenizer.convert_tokens_to_ids(pieces)
        mask = [attend_flag] * len(ids)

        # Pad up to the sequence length.
        shortfall = max_seq_length - len(ids)
        id_fill = [pad_token] * shortfall
        mask_fill = [ignore_flag] * shortfall
        segment_fill = [pad_token_segment_id] * shortfall
        target_fill = [pad_token_label_id] * shortfall
        if pad_on_left:
            ids = id_fill + ids
            mask = mask_fill + mask
            segments = segment_fill + segments
            targets = target_fill + targets
        else:
            ids = ids + id_fill
            mask = mask + mask_fill
            segments = segments + segment_fill
            targets = targets + target_fill

        for column in (ids, mask, segments, targets):
            assert len(column) == max_seq_length

        if position < 2 and verbose == 1:
            logger.info("*** Example ***")
            logger.info("guid: %s", sample.guid)
            logger.info("tokens: %s", " ".join(map(str, pieces)))
            logger.info("input_ids: %s", " ".join(map(str, ids)))
            logger.info("input_mask: %s", " ".join(map(str, mask)))
            logger.info("segment_ids: %s", " ".join(map(str, segments)))
            logger.info("label_ids: %s", " ".join(map(str, targets)))

        # Drop the segment ids for tokenizers that do not list them as a model input.
        if "token_type_ids" not in tokenizer.model_input_names:
            segments = None

        converted.append(
            InputFeatures(input_ids=ids, attention_mask=mask, token_type_ids=segments, label_ids=targets)
        )
    return converted
def get_labels(path: str) -> List[str]:
    """
    Return the list of label names.

    Args:
        path: Path to a text file with one label per line. When falsy
            (``None`` or ``""``), a built-in default label set is returned.

    Returns:
        The labels read from ``path`` in file order, with ``"O"`` prepended if
        the file does not contain it; or the default label set.
    """
    if not path:
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

    # Read as UTF-8, like the data files, rather than with the
    # platform-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        labels = f.read().splitlines()
    if "O" not in labels:
        labels = ["O"] + labels
    return labels
def generate_input_files(ehr_records: List[HealthRecord], filename: str,
                         ade_records: List[Dict] = None, max_len: int = 510,
                         sep: str = ' '):
    """
    Write EHR and ADE records to a file.

    Each token is written on its own line as ``<token><sep><label>``. An EHR
    record is cut into chunks at the record's split points, and chunks, records
    and ADE records are separated by blank lines. The file is written as UTF-8.

    Parameters
    ----------
    ehr_records : List[HealthRecord]
        List of EHR records.

    filename : str
        File name to write to.

    ade_records : List[Dict], optional
        List of ADE records, each with a 'tokens' list and an 'entities'
        mapping. The default is None, in which case only EHR records are
        written.

    max_len : int, optional
        Max length of an example. The default is 510.

    sep : str, optional
        Token-label separator. The default is a space.
    """
    # Maps ADE entity names to the tag suffix written to the file.
    ent_label_map = {'Drug': 'DRUG', 'Adverse-Effect': 'ADE', 'ADE': 'ADE'}

    # Written as UTF-8 explicitly rather than with the platform default encoding.
    with open(filename, 'w', encoding='utf-8') as f:
        for record in ehr_records:
            # NOTE(review): assumes get_split_points returns at least two
            # indices (split_idx[1] is read unconditionally) - confirm.
            split_idx = record.get_split_points(max_len=max_len)
            labels = record.get_labels()
            tokens = record.get_tokens()

            start = split_idx[0]
            end = split_idx[1]

            for i in range(1, len(split_idx)):
                # Both bounds are inclusive, hence the "+ 1" on the slice end.
                for (token, label) in zip(tokens[start:end + 1], labels[start:end + 1]):
                    f.write('{}{}{}\n'.format(token, sep, label))

                start = end + 1
                if i != len(split_idx) - 1:
                    end = split_idx[i + 1]
                    # Blank line between consecutive chunks of the same record.
                    f.write('\n')

            # Blank line after each record.
            f.write('\n')

        if ade_records is not None:
            for ade in ade_records:
                ade_tokens = ade['tokens']
                ade_entities = ade['entities']

                # Everything is outside an entity until an entity claims it.
                ade_labels = ['O'] * len(ade_tokens)

                for ent in ade_entities.values():
                    ent_type = ent.name
                    start_idx = ent.range[0]
                    end_idx = ent.range[1]

                    # ent.range is treated as inclusive: the first token gets
                    # the "B-" tag, the remaining ones "I-".
                    for idx in range(start_idx, end_idx + 1):
                        if idx == start_idx:
                            ade_labels[idx] = 'B-' + ent_label_map[ent_type]
                        else:
                            ade_labels[idx] = 'I-' + ent_label_map[ent_type]

                for (token, label) in zip(ade_tokens, ade_labels):
                    f.write('{}{}{}\n'.format(token, sep, label))

                # Blank line after each ADE record.
                f.write('\n')

    print("Data successfully saved in " + filename)