# biobert_re/utils_re.py

import os
import time
import random
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Union, Dict, Tuple

import torch
from torch.utils.data.dataset import Dataset

from filelock import FileLock

import logging

from transformers import (InputFeatures,
                          InputExample,
                          PreTrainedTokenizerBase)

import pandas as pd
from sklearn.metrics import precision_recall_fscore_support


import sys
sys.path.append("../")
sys.path.append('./biobert_re/')

from data_processor import glue_convert_examples_to_features, glue_output_modes, glue_processors

import utils
from ehr import HealthRecord
from annotations import Relation

logger = logging.getLogger(__name__)


@dataclass
class GlueDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
    line.
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        self.task_name = self.task_name.lower()
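
# A minimal usage sketch, assuming the usual transformers CLI setup: these arguments can be
# parsed together with TrainingArguments via HfArgumentParser, or constructed directly
# (task_name must be one of the keys in glue_processors).
#
#   from transformers import HfArgumentParser, TrainingArguments
#   parser = HfArgumentParser((GlueDataTrainingArguments, TrainingArguments))
#   data_args, training_args = parser.parse_args_into_dataclasses()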


class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"


# noinspection PyTypeChecker
class REDataset(Dataset):
    """
    A class representing a training dataset for Relation Extraction.
    """

    args: GlueDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    def __init__(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizerBase,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        cache_dir: Optional[str] = None,
    ):
        self.args = args
        self.processor = glue_processors[args.task_name]()
        self.output_mode = glue_output_modes[args.task_name]
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")

        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}_{}".format(
                mode.value,
                tokenizer.__class__.__name__,
                str(args.max_seq_length),
                args.task_name,
            ),
        )

        label_list = self.processor.get_labels()

        self.label_list = label_list

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):

            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info("Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start)
            else:
                logger.info(f"Creating features from dataset file at {args.data_dir}")

                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                elif mode == Split.test:
                    examples = self.processor.get_test_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)
                if limit_length is not None:
                    examples = examples[:limit_length]
                self.features = glue_convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )
                start = time.time()
                torch.save(self.features, cached_features_file)

                logger.info("Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

    def get_labels(self):
        return self.label_list
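
# Usage sketch, with placeholder checkpoint, task name and data directory (adjust to the
# actual glue_processors key and dataset location):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
#   data_args = GlueDataTrainingArguments(task_name="ehr-re", data_dir="dataset/")
#   train_dataset = REDataset(data_args, tokenizer, mode=Split.train)
#   dev_dataset = REDataset(data_args, tokenizer, mode="dev")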


class RETestDataset(Dataset):
    """
    A class representing a test Dataset for relation extraction.
    """

    def __init__(self, test_ehr, tokenizer, max_seq_len, label_list):

        self.re_text_list, self.relation_list = generate_re_test_file(test_ehr)

        if not self.re_text_list:
            self.features = []
        else:
            examples = []
            for (i, text) in enumerate(self.re_text_list):
                guid = "%s" % i
                examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=None))

            self.features = glue_convert_examples_to_features(examples, tokenizer,
                                                              max_length=max_seq_len,
                                                              label_list=label_list)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]
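
# Sketch, assuming an annotated HealthRecord, a loaded tokenizer and a transformers Trainer
# (the label_list values are placeholders):
#
#   test_dataset = RETestDataset(test_ehr, tokenizer, max_seq_len=128, label_list=['0', '1'])
#   predictions = trainer.predict(test_dataset)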


def replace_ent_label(text, ent_type, start_idx, end_idx):
    label = '@' + ent_type + '$'
    return text[:start_idx] + label + text[end_idx:]
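
# For example (hypothetical sentence): replace_ent_label("Aspirin relieves pain", "Drug", 0, 7)
# returns "@Drug$ relieves pain" -- the character span [0, 7) is replaced by the entity tag.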


def write_file(file, index, sentence, label, sep, is_test, is_label):
    if is_test and is_label:  # test_original - test with labels
        file.write('{}{}{}{}{}'.format(index, sep, sentence, sep, label))
    elif is_test and not is_label:  # test - test with no labels
        file.write('{}{}{}'.format(index, sep, sentence))
    else:  # train
        file.write('{}{}{}'.format(sentence, sep, label))
    file.write('\n')
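
# With a tab separator this produces, per line:
#   test with labels:    "<index>\t<sentence>\t<label>"
#   test without labels: "<index>\t<sentence>"
#   train:               "<sentence>\t<label>"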


def get_char_split_points(record, max_len):
    """
    Converts the record's split points (as returned by record.get_split_points)
    into character indices used to slice the text.
    """
    char_split_points = []

    split_points = record.get_split_points(max_len=max_len)
    for pt in split_points[:-1]:
        char_split_points.append(record.get_char_idx(pt)[1])

    if len(char_split_points) == 1:
        return char_split_points
    else:
        return char_split_points[1:]


def replace_entity_text(split_text, ent1, ent2, split_offset):
    # Remove split offset
    ent1_start = ent1.range[0] - split_offset
    ent1_end = ent1.range[1] - split_offset

    ent2_start = ent2.range[0] - split_offset
    ent2_end = ent2.range[1] - split_offset

    # If entity 1 is present before entity 2
    if ent1_end < ent2_end:
        # Insert entity 2 and then entity 1
        modified_text = replace_ent_label(split_text, ent2.name, ent2_start, ent2_end)
        modified_text = replace_ent_label(modified_text, ent1.name, ent1_start, ent1_end)

    # If entity 1 is present after entity 2
    else:
        # Insert entity 1 and then entity 2
        modified_text = replace_ent_label(split_text, ent1.name, ent1_start, ent1_end)
        modified_text = replace_ent_label(modified_text, ent2.name, ent2_start, ent2_end)

    return modified_text
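
# Note: whichever entity occurs later in the text is replaced first, so the character
# offsets of the earlier entity stay valid after the substitution changes the text length.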


def generate_re_input_files(ehr_records: List[HealthRecord], filename: str,
                            ade_records: List[Dict] = None, max_len: int = 128,
                            is_test=False, is_label=True, is_predict=False, sep: str = '\t'):
    """
    Generates a relation extraction input file from EHR (and optionally ADE corpus) records,
    and saves a pickle file mapping each example index to its Relation object (and label,
    when available).
    """
    random.seed(0)

    index = 0
    index_rel_label_map = []

    with open(filename, 'w') as file:
        # Write headers
        write_file(file, 'index', 'sentence', 'label', sep, is_test, is_label)

        # Preprocess EHR records
        for record in ehr_records:
            text = record.text
            entities = record.get_entities()

            if is_predict:
                true_relations = None
            else:
                true_relations = record.get_relations()

            # get character split points
            char_split_points = get_char_split_points(record, max_len)

            start = 0
            end = char_split_points[0]

            for i in range(len(char_split_points)):
                # Obtain only entities within the split text
                range_entities = {ent_id: ent for ent_id, ent in
                                  filter(lambda item: int(item[1][0]) >= start and int(item[1][1]) <= end,
                                         entities.items())}

                # Get all possible relations within the split text
                possible_relations = utils.map_entities(range_entities, true_relations)

                for rel, label in possible_relations:
                    # Randomly downsample negative, non-ADE-Drug candidate pairs (keep ~25%)
                    if label == 0 and rel.name != "ADE-Drug":
                        if random.random() > 0.25:
                            continue

                    split_text = text[start:end]
                    split_offset = start

                    ent1 = rel.get_entities()[0]
                    ent2 = rel.get_entities()[1]

                    # Check if both entities are within split text
                    if ent1.range[0] >= start and ent1.range[1] < end and \
                            ent2.range[0] >= start and ent2.range[1] < end:

                        modified_text = replace_entity_text(split_text, ent1, ent2, split_offset)

                        # Replace un-required characters with space
                        final_text = modified_text.replace('\n', ' ').replace('\t', ' ')
                        write_file(file, index, final_text, label, sep, is_test, is_label)

                        if is_predict:
                            index_rel_label_map.append({'relation': rel})
                        else:
                            index_rel_label_map.append({'label': label, 'relation': rel})

                        index += 1

                start = end
                if i != len(char_split_points)-1:
                    end = char_split_points[i+1]
                else:
                    end = len(text)+1

        # Preprocess ADE records
        if ade_records is not None:
            for record in ade_records:
                entities = record['entities']
                true_relations = record['relations']
                possible_relations = utils.map_entities(entities, true_relations)

                for rel, label in possible_relations:

                    # Randomly skip about half of the positive ADE examples
                    if label == 1 and random.random() > 0.5:
                        continue

                    new_tokens = record['tokens'].copy()

                    for ent in rel.get_entities():
                        ent_type = ent.name

                        start_tok = ent.range[0]
                        end_tok = ent.range[1]+1

                        for i in range(start_tok, end_tok):
                            new_tokens[i] = '@'+ent_type+'$'

                    # Remove consecutive repeating entity tags.
                    # E.g. "this is @ADE$ @ADE$ @ADE$ for @Drug$ @Drug$" -> "this is @ADE$ for @Drug$"
                    final_tokens = [new_tokens[i] for i in range(len(new_tokens))
                                    if (i == 0) or new_tokens[i] != new_tokens[i-1]]

                    final_text = " ".join(final_tokens)

                    write_file(file, index, final_text, label, sep, is_test, is_label)
                    index_rel_label_map.append({'label': label, 'relation': rel})
                    index += 1

    # Save the index -> relation/label mapping next to the generated file
    filename, ext = os.path.splitext(filename)
    utils.save_pickle(filename+'_rel.pkl', index_rel_label_map)
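
# Usage sketch (hypothetical record lists and paths):
#   generate_re_input_files(train_records, 'dataset/train.tsv', ade_records=ade_train, max_len=128)
#   generate_re_input_files(test_records, 'dataset/test.tsv', is_test=True, is_label=False, is_predict=True)
# Each call also writes '<filename>_rel.pkl' with the index-to-relation mapping.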


def get_eval_results(answer_path, output_path):
    """
    Get evaluation metrics for predictions.

    Parameters
    ------------
    answer_path : str
        Path to the test.tsv file (tab-separated). One example per line, with the
        true label in the 3rd column.

    output_path : str
        Path to test_predictions.txt containing the model-generated predictions.

    Returns
    --------
    dict
        Dictionary with f1 score, recall, precision and specificity.
    """
    testdf = pd.read_csv(answer_path, sep="\t", index_col=0)
    preddf = pd.read_csv(output_path, sep="\t", header=None)

    pred = [preddf.iloc[i].tolist() for i in preddf.index]
    pred_class = [int(v[1]) for v in pred[1:]]

    p, r, f, s = precision_recall_fscore_support(y_pred=pred_class, y_true=testdf["label"])
    results = dict()
    results["f1 score"] = f[1]
    results["recall"] = r[1]
    results["precision"] = p[1]
    results["specificity"] = r[0]
    return results
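
# Example sketch (hypothetical paths), after running prediction with the RE model:
#   metrics = get_eval_results('dataset/test_original.tsv', 'output/test_predictions.txt')
#   # -> {'f1 score': ..., 'recall': ..., 'precision': ..., 'specificity': ...}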


def generate_re_test_file(ehr_record: HealthRecord,
                          max_len: int = 128) -> Tuple[List[str], List[Relation]]:
    """
    Generates test file for Relation Extraction.

    Parameters
    -----------
    ehr_record : HealthRecord
        The EHR record with entities set.

    max_len : int
        The maximum length of sequence.

    Returns
    --------
    Tuple[List[str], List[Relation]]
        List of sequences with each entity replaced by its tag,
        and a list of Relation objects representing the relations in those sequences.
    """
    random.seed(0)

    re_text_list = []
    relation_list = []

    text = ehr_record.text
    entities = ehr_record.get_entities()
    if isinstance(entities, dict):
        entities = list(entities.values())

    # get character split points
    char_split_points = get_char_split_points(ehr_record, max_len)

    start = 0
    end = char_split_points[0]

    for i in range(len(char_split_points)):
        # Obtain only entities within the split text
        range_entities = [ent for ent in filter(lambda item: int(item[0]) >= start and int(item[1]) <= end,
                                                entities)]

        # Get all possible relations within the split text
        possible_relations = utils.map_entities(range_entities)

        for rel, label in possible_relations:
            split_text = text[start:end]
            split_offset = start

            ent1 = rel.get_entities()[0]
            ent2 = rel.get_entities()[1]

            # Check if both entities are within split text
            if ent1[0] >= start and ent1[1] < end and \
                    ent2[0] >= start and ent2[1] < end:

                modified_text = replace_entity_text(split_text, ent1, ent2, split_offset)

                # Replace un-required characters with space
                final_text = modified_text.replace('\n', ' ').replace('\t', ' ')

                re_text_list.append(final_text)
                relation_list.append(rel)

        start = end
        if i != len(char_split_points)-1:
            end = char_split_points[i+1]
        else:
            end = len(text)+1

    assert len(re_text_list) == len(relation_list)

    return re_text_list, relation_list