Diff of /ehr.py [000000] .. [1de6ed]

from annotations import Entity, Relation
from typing import List, Dict, Union, Tuple, Callable, Optional
import warnings
import numpy


class HealthRecord:
    """
    Objects that represent a single electronic health record
    """

    def __init__(self, record_id: str = "1", text_path: Optional[str] = None,
                 ann_path: Optional[str] = None,
                 text: Optional[str] = None,
                 tokenizer: Optional[Callable[[str], List[str]]] = None,
                 is_bert_tokenizer: bool = True,
                 is_training: bool = True) -> None:
        """
        Initializes a health record object.

        Parameters
        ----------
        record_id : str
            A unique ID for the record. The default is "1".

        text_path : str, optional
            Path for the EHR record txt file.

        ann_path : str, optional
            Path for the annotation file. The default is None.

        text : str, optional
            The actual text of the record, if text_path is not
            specified.

        tokenizer : Callable[[str], List[str]], optional
            The tokenizer function to use. The default is None.

        is_bert_tokenizer : bool
            Whether the tokenizer is a BERT-based wordpiece tokenizer.
            The default is True.

        is_training : bool, optional
            Specifies if the record is a training example.
            The default is True.
        """
        if is_training and ann_path is None:
            raise AttributeError("Annotation path needs to be "
                                 "specified for training example.")

        if text_path is None and text is None:
            raise AttributeError("Either text or text path must be "
                                 "specified.")

        self.record_id = record_id
        self.is_training = is_training

        if text_path is not None:
            self.text = self._read_ehr(text_path)
        else:
            self.text = text

        self.char_to_token_map: List[int] = []
        self.token_to_char_map: List[Tuple[int, int]] = []
        self.tokenizer = None
        self.is_bert_tokenizer = is_bert_tokenizer
        self.elmo = None
        self.elmo_embeddings = None
        self.set_tokenizer(tokenizer)
        self.split_idx = None

        if ann_path is not None:
            annotations = self._extract_annotations(ann_path)
            self.entities, self.relations = annotations
        else:
            self.entities = None
            self.relations = None

    @staticmethod
    def _read_ehr(path: str) -> str:
        """
        Internal function to read EHR data.

        Parameters
        ----------
        path : str
            Path for EHR record.

        Returns
        -------
        str
            EHR record as a string.
        """
        with open(path) as f:
            raw_data = f.read()
        return raw_data

    @staticmethod
    def _extract_annotations(path: str) \
            -> Tuple[Dict[str, Entity], Dict[str, Relation]]:
        """
        Internal function that extracts entities and relations
        as a dictionary from an annotation file.

        Parameters
        ----------
        path : str
            Path for the ann file.

        Returns
        -------
        Tuple[Dict[str, Entity], Dict[str, Relation]]
            Entities and relations.
        """
        with open(path) as f:
            raw_data = f.read().split('\n')

        entities = {}
        relations = {}

        # Relations with entities that haven't been processed yet
        relation_backlog = []

        for line in raw_data:
            if line.startswith('#'):
                continue

            line = line.split('\t')

            # Remove empty strings from list
            line = list(filter(None, line))

            if not line or not line[0]:
                continue

            if line[0][0] == 'T':
                assert len(line) == 3

                idx = 0
                # Find the end of first word, which is the entity type
                for idx in range(len(line[1])):
                    if line[1][idx] == ' ':
                        break

                char_ranges = line[1][idx + 1:]

                # Get all character ranges, separated by ;
                char_ranges = [r.split() for r in char_ranges.split(';')]

                # Create an Entity object
                ent = Entity(entity_id=line[0],
                             entity_type=line[1][:idx])

                r = [char_ranges[0][0], char_ranges[-1][1]]
                r = list(map(int, r))
                ent.set_range(r)

                ent.set_text(line[2])
                entities[line[0]] = ent

            elif line[0][0] == 'R':
                assert len(line) == 2

                rel_details = line[1].split(' ')
                entity1 = rel_details[1].split(':')[-1]
                entity2 = rel_details[2].split(':')[-1]

                if entity1 in entities and entity2 in entities:
                    rel = Relation(relation_id=line[0],
                                   relation_type=rel_details[0],
                                   arg1=entities[entity1],
                                   arg2=entities[entity2])

                    relations[line[0]] = rel
                else:
                    # If the entities aren't processed yet,
                    # add them to backlog to process later
                    relation_backlog.append([line[0], rel_details[0],
                                             entity1, entity2])

            else:
                # If the annotation is not a relation or entity, warn user
                msg = f"Invalid annotation encountered: {line}, File: {path}"
                warnings.warn(msg)

        for r in relation_backlog:
            rel = Relation(relation_id=r[0], relation_type=r[1],
                           arg1=entities[r[2]], arg2=entities[r[3]])

            relations[r[0]] = rel

        return entities, relations
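
    # Example (illustrative; the IDs, offsets, and types below are hypothetical):
    # a brat-style line 'T1\tDrug 10 17\taspirin' is parsed into
    # Entity(entity_id='T1', entity_type='Drug') with range [10, 17] and
    # text 'aspirin', while a line 'R1\tReason-Drug Arg1:T2 Arg2:T1' is parsed
    # into Relation(relation_id='R1', relation_type='Reason-Drug') whose
    # arguments are the already-extracted entities T2 and T1.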

    def _compute_tokens(self) -> None:
        """
        Computes the tokens and character <-> token index mappings
        for EHR text data.
        """
        self.tokens = [str(token) for token in self.tokenizer(self.text)]

        char_to_token_map = []
        token_to_char_map = []

        j = 0
        k = 0

        for i in range(len(self.tokens)):
            # For BioBERT, a split within a word is denoted by ##
            if self.is_bert_tokenizer and self.tokens[i].startswith("##"):
                k += 2

            # Characters that are discarded from tokenization
            while self.text[j].lower() != self.tokens[i][k].lower():
                char_to_token_map.append(char_to_token_map[-1])
                j += 1

            # For SciSpacy, if there are multiple spaces, it removes
            # one and keeps the rest
            if self.text[j] == ' ' and self.text[j + 1] == ' ':
                char_to_token_map.append(char_to_token_map[-1])
                j += 1

            token_start_idx = j
            # Go over each letter in token and original text
            while k < len(self.tokens[i]):
                if self.text[j].lower() == self.tokens[i][k].lower():
                    char_to_token_map.append(i)
                    j += 1
                    k += 1
                else:
                    msg = f"Error computing token to char map. ID: {self.record_id}"
                    raise Exception(msg)

            token_end_idx = j
            token_to_char_map.append((token_start_idx, token_end_idx))
            k = 0

        # Characters at the end which are discarded by tokenizer
        while j < len(self.text):
            char_to_token_map.append(char_to_token_map[-1])
            j += 1

        assert len(char_to_token_map) == len(self.text)
        assert len(token_to_char_map) == len(self.tokens)

        self.char_to_token_map = char_to_token_map
        self.token_to_char_map = token_to_char_map
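
    # Worked example (illustrative, assuming a plain whitespace tokenizer and
    # is_bert_tokenizer=False): for the text 'Aspirin 81 mg' the tokens are
    # ['Aspirin', '81', 'mg'], char_to_token_map is
    # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2] (each discarded space keeps the
    # previous token's index), and token_to_char_map is
    # [(0, 7), (8, 10), (11, 13)] with end offsets exclusive.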

    def get_tokens(self) -> List[str]:
        """
        Returns the tokens.

        Returns
        -------
        List[str]
            List of tokens.
        """
        if self.tokenizer is None:
            raise AttributeError("Tokenizer not set.")

        return self.tokens

    def set_tokenizer(self, tokenizer: Optional[Callable[[str], List[str]]]) \
            -> None:
        """
        Set the tokenizer for the object.

        Parameters
        ----------
        tokenizer : Callable[[str], List[str]], optional
            The tokenizer function to use. If a tokenizer is given,
            tokens are recomputed immediately.
        """
        self.tokenizer = tokenizer
        if tokenizer is not None:
            self._compute_tokens()
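
    # Note (based on _compute_tokens above): the tokenizer must produce tokens
    # whose characters can be matched back to self.text in order, e.g. a
    # whitespace split such as str.split, a SciSpacy tokenizer, or a BERT
    # wordpiece tokenizer used with is_bert_tokenizer=True; a tokenizer that
    # rewrites or normalizes characters would make _compute_tokens raise.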

    def get_token_idx(self, char_idx: int) -> int:
        """
        Returns the token index for a character index.

        Parameters
        ----------
        char_idx : int
            Character index.

        Returns
        -------
        int
            Token index.
        """
        if self.tokenizer is None:
            raise AttributeError("Tokenizer not set.")

        token_idx = self.char_to_token_map[char_idx]

        return token_idx

    def get_char_idx(self, token_idx: int) -> Tuple[int, int]:
        """
        Returns the character range (start, end) for the specified
        token index.

        Parameters
        ----------
        token_idx : int
            Token index.

        Returns
        -------
        Tuple[int, int]
            Character range, with the end index exclusive.
        """
        if self.tokenizer is None:
            raise AttributeError("Tokenizer not set.")

        char_range = self.token_to_char_map[token_idx]

        return char_range

    def get_labels(self) -> List[str]:
        """
        Get token labels in IOB format.

        Returns
        -------
        List[str]
            Labels.
        """
        if self.tokenizer is None:
            raise AttributeError("No tokens found. Set tokenizer first.")

        ent_label_map = {'Drug': 'DRUG', 'Strength': 'STR', 'Duration': 'DUR',
                         'Route': 'ROU', 'Form': 'FOR', 'ADE': 'ADE',
                         'Dosage': 'DOS', 'Reason': 'REA', 'Frequency': 'FRE'}

        labels = ['O'] * len(self.tokens)

        for ent in self.entities.values():
            start_idx = self.get_token_idx(ent.range[0])
            end_idx = self.get_token_idx(ent.range[1])

            for idx in range(start_idx, end_idx + 1):
                if idx == start_idx:
                    labels[idx] = 'B-' + ent_label_map[ent.name]
                else:
                    labels[idx] = 'I-' + ent_label_map[ent.name]

        return labels
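
    # Example (illustrative, hypothetical annotations): for the text
    # 'Aspirin 81 mg daily.' split on whitespace, with annotated entities
    # Drug = 'Aspirin', Strength = '81 mg' and Frequency = 'daily', the
    # returned labels would be ['B-DRUG', 'B-STR', 'I-STR', 'B-FRE'].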

    def get_split_points(self, max_len: int = 510,
                         new_line_ind: Optional[List[str]] = None,
                         sent_end_ind: Optional[List[str]] = None) -> List[int]:
        """
        Get the splitting points for tokens.

        > It includes as many paragraphs as it can within the
        max_len token limit (the default of 510 leaves room for
        the 2 special tokens BERT adds).

        > If it can't fit even a single complete paragraph,
        it splits on the last new line that starts a new sentence.

        > If it can't find that either, it splits at the max_len
        token limit.

        Parameters
        ----------
        max_len : int, optional
            Maximum number of tokens in one example. The default is
            510 for BERT.

        new_line_ind : List[str], optional
            New-line indicators: non-digit characters that mark the
            start of a new line. The default is
            ['[', '#', '-', '>', ' '].

        sent_end_ind : List[str], optional
            Sentence-end indicators. The default is ['.', '?', '!'].

        Returns
        -------
        List[int]
            Splitting indices, including the first and last index.
            Add 1 to the end indices when accessing tokens with
            list slicing.
        """
        if new_line_ind is None:
            new_line_ind = ['[', '#', '-', '>', ' ']

        if sent_end_ind is None:
            sent_end_ind = ['.', '?', '!']

        split_idx = [0]
        last_par_end_idx = 0
        last_line_end_idx = 0

        for i in range(len(self.text)):
            curr_counter = self.get_token_idx(i) - split_idx[-1]

            if curr_counter >= max_len:
                # If not even a single paragraph has ended
                if last_par_end_idx == 0 and last_line_end_idx != 0:
                    split_idx.append(last_line_end_idx)

                elif last_par_end_idx != 0:
                    split_idx.append(last_par_end_idx)

                else:
                    split_idx.append(self.get_token_idx(i))

                last_par_end_idx = 0
                last_line_end_idx = 0

            if i < len(self.text) - 2 and self.text[i] == '\n':
                if self.text[i + 1] == '\n':
                    last_par_end_idx = self.get_token_idx(i - 1)

                if self.text[i + 1] == '.' or self.text[i + 1] == '*':
                    last_par_end_idx = self.get_token_idx(i + 1)

                if self.text[i + 1] in new_line_ind or \
                        self.text[i + 1].isdigit() or \
                        self.text[i - 1] in sent_end_ind:
                    last_line_end_idx = self.get_token_idx(i)

        split_idx.append(len(self.tokens))
        self.split_idx = split_idx

        return self.split_idx
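
    # Illustrative reading of the result (values are made up): a return value
    # of [0, 486, 953] means the record can be covered by the token windows
    # tokens[0:487] and tokens[486:954], following the note above that end
    # indices are inclusive and need +1 when slicing.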

    def get_annotations(self) -> Dict[str, Union[Dict[str, Entity],
                                                 Dict[str, Relation]]]:
        """
        Get entities and relations in a dictionary.
        Entities are referenced with the key 'entities'
        and relations with 'relations'.

        Returns
        -------
        Dict[str, Union[Dict[str, Entity], Dict[str, Relation]]]
            Entities and relations.
        """
        if self.entities is None or self.relations is None:
            raise AttributeError("Annotations not available")

        return {'entities': self.entities, 'relations': self.relations}

    def get_entities(self) -> Dict[str, Entity]:
        """
        Get the entities.

        Returns
        -------
        Dict[str, Entity]
            Entity ID: Entity object.
        """
        if self.entities is None:
            raise AttributeError("Entities not set")

        return self.entities

    def get_relations(self) -> Dict[str, Relation]:
        """
        Get the entity relations.

        Returns
        -------
        Dict[str, Relation]
            Relation ID: Relation object.
        """
        if self.relations is None:
            raise AttributeError("Relations not set")

        return self.relations

    def _compute_elmo_embeddings(self) -> None:
        """
        Computes the Elmo embeddings for each token in the EHR text data.
        """
        # noinspection PyUnresolvedReferences
        elmo_embeddings = self.elmo.embed_sentence(self.tokens)[-1]
        self.elmo_embeddings = elmo_embeddings

    def set_elmo_embedder(self, elmo: Callable[[str], numpy.ndarray]) -> None:
        """
        Set the Elmo embedder for the object.

        Parameters
        ----------
        elmo :
            The Elmo embedder to use. It must provide an
            ``embed_sentence(tokens)`` method returning per-layer
            embeddings; the last layer is stored.
        """
        self.elmo = elmo
        if elmo is not None:
            self._compute_elmo_embeddings()
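
    # Note (assumption): any embedder object exposing an embed_sentence(tokens)
    # method that returns an array of shape (num_layers, num_tokens, dim)
    # fits here, e.g. AllenNLP's ElmoEmbedder; _compute_elmo_embeddings keeps
    # only the last layer.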

    def get_elmo_embeddings(self) -> numpy.ndarray:
        """
        Get the Elmo embeddings.

        Returns
        -------
        numpy.ndarray
            Elmo embeddings for each token.
        """
        if self.elmo_embeddings is None:
            raise AttributeError("Elmo embeddings not set")

        return self.elmo_embeddings
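

# Minimal usage sketch (illustrative): the file paths below are hypothetical
# and str.split stands in for a real clinical tokenizer.
if __name__ == "__main__":
    record = HealthRecord(record_id="100035",
                          text_path="data/100035.txt",
                          ann_path="data/100035.ann",
                          tokenizer=str.split,
                          is_bert_tokenizer=False)

    tokens = record.get_tokens()
    labels = record.get_labels()          # IOB tags aligned with tokens
    splits = record.get_split_points()    # window boundaries for long records

    print(len(tokens), labels[:10], splits)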