Diff of /ehr.py [000000] .. [1de6ed]

--- a
+++ b/ehr.py
@@ -0,0 +1,512 @@
+from annotations import Entity, Relation
+from typing import List, Dict, Union, Tuple, Callable, Optional
+import warnings
+import numpy
+
+
+class HealthRecord:
+    """
+    Represents a single electronic health record.
+    """
+
+    def __init__(self, record_id: str = "1", text_path: Optional[str] = None,
+                 ann_path: Optional[str] = None,
+                 text: Optional[str] = None,
+                 tokenizer: Optional[Callable[[str], List[str]]] = None,
+                 is_bert_tokenizer: bool = True,
+                 is_training: bool = True) -> None:
+        """
+        Initializes a health record object
+
+        Parameters
+        ----------
+        record_id : str, optional
+            A unique ID for the record. The default is "1".
+
+        text_path : str, optional
+            Path for the EHR record txt file. The default is None.
+
+        ann_path : str, optional
+            Path for the annotation file. Required if is_training
+            is True. The default is None.
+
+        text : str, optional
+            The actual text for the record, if text_path is not
+            specified. The default is None.
+
+        tokenizer: Callable[[str], List[str]], optional
+            The tokenizer function to use. The default is None.
+
+        is_bert_tokenizer : bool, optional
+            Whether the tokenizer is a BERT-based wordpiece
+            tokenizer. The default is True.
+
+        is_training : bool, optional
+            Specifies if the record is a training example.
+            The default is True.
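+
+        Example (illustrative; the paths and the plain whitespace
+        tokenizer below are stand-ins, not a prescribed setup):
+
+        >>> record = HealthRecord(record_id="100035",
+        ...                       text_path="data/100035.txt",
+        ...                       ann_path="data/100035.ann",
+        ...                       tokenizer=str.split,
+        ...                       is_bert_tokenizer=False)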
+        """
+        if is_training and ann_path is None:
+            raise AttributeError("Annotation path needs to be "
+                                 "specified for training example.")
+
+        if text_path is None and text is None:
+            raise AttributeError("Either text or text path must be "
+                                 "specified.")
+
+        self.record_id = record_id
+        self.is_training = is_training
+
+        if text_path is not None:
+            self.text = self._read_ehr(text_path)
+        else:
+            self.text = text
+
+        self.char_to_token_map: List[int] = []
+        self.token_to_char_map: List[Tuple[int, int]] = []
+        self.tokenizer = None
+        self.is_bert_tokenizer = is_bert_tokenizer
+        self.elmo = None
+        self.elmo_embeddings = None
+        self.set_tokenizer(tokenizer)
+        self.split_idx = None
+
+        if ann_path is not None:
+            annotations = self._extract_annotations(ann_path)
+            self.entities, self.relations = annotations
+
+        else:
+            self.entities = None
+            self.relations = None
+
+    @staticmethod
+    def _read_ehr(path: str) -> str:
+        """
+        Internal function to read EHR data.
+
+        Parameters
+        ----------
+        path : str
+            Path for EHR record.
+
+        Returns
+        -------
+        str
+            EHR record as a string.
+        """
+        with open(path) as f:
+            raw_data = f.read()
+        return raw_data
+
+    @staticmethod
+    def _extract_annotations(path: str) \
+            -> Tuple[Dict[str, Entity], Dict[str, Relation]]:
+        """
+        Internal function that extracts entities and relations
+        as a dictionary from an annotation file.
+
+        Parameters
+        ----------
+        path : str
+            Path for the ann file.
+
+        Returns
+        -------
+        Tuple[Dict[str, Entity], Dict[str, Relation]]
+            Entities and relations.
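+
+        The expected format is brat-style standoff, tab-separated
+        (the IDs and values below are illustrative):
+
+            T1	Drug 10 19	metformin
+            T2	Reason 31 39	diabetes
+            R1	Reason-Drug Arg1:T2 Arg2:T1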
+        """
+        with open(path) as f:
+            raw_data = f.read().split('\n')
+
+        entities = {}
+        relations = {}
+
+        # Relations with entities that haven't been processed yet
+        relation_backlog = []
+
+        for line in raw_data:
+            if line.startswith('#'):
+                continue
+
+            line = line.split('\t')
+
+            # Remove empty strings from list
+            line = list(filter(None, line))
+
+            if not line or not line[0]:
+                continue
+
+            if line[0][0] == 'T':
+                assert len(line) == 3
+
+                # The first word is the entity type; everything
+                # after the first space is the character range(s)
+                idx = line[1].find(' ')
+
+                char_ranges = line[1][idx + 1:]
+
+                # Get all character ranges, separated by ;
+                char_ranges = [r.split() for r in char_ranges.split(';')]
+
+                # Create an Entity object
+                ent = Entity(entity_id=line[0],
+                             entity_type=line[1][:idx])
+
+                r = [char_ranges[0][0], char_ranges[-1][1]]
+                r = list(map(int, r))
+                ent.set_range(r)
+
+                ent.set_text(line[2])
+                entities[line[0]] = ent
+
+            elif line[0][0] == 'R':
+                assert len(line) == 2
+
+                rel_details = line[1].split(' ')
+                entity1 = rel_details[1].split(':')[-1]
+                entity2 = rel_details[2].split(':')[-1]
+
+                if entity1 in entities and entity2 in entities:
+                    rel = Relation(relation_id=line[0],
+                                   relation_type=rel_details[0],
+                                   arg1=entities[entity1],
+                                   arg2=entities[entity2])
+
+                    relations[line[0]] = rel
+                else:
+                    # If the entities aren't processed yet, add the
+                    # relation to a backlog to process later
+                    relation_backlog.append([line[0], rel_details[0],
+                                             entity1, entity2])
+
+            else:
+                # If the annotation is not a relation or entity, warn user
+                msg = f"Invalid annotation encountered: {line}, File: {path}"
+                warnings.warn(msg)
+
+        for r in relation_backlog:
+            rel = Relation(relation_id=r[0], relation_type=r[1],
+                           arg1=entities[r[2]], arg2=entities[r[3]])
+
+            relations[r[0]] = rel
+
+        return entities, relations
+
+    def _compute_tokens(self) -> None:
+        """
+        Computes the tokens and character <-> token index mappings
+        for EHR text data.
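+
+        For example, with a whitespace tokenizer and the text
+        "a bc" (tokens ["a", "bc"]), char_to_token_map is
+        [0, 0, 1, 1] (the space maps back to the previous token)
+        and token_to_char_map is [(0, 1), (2, 4)], where the end
+        indices are exclusive.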
+        """
+        self.tokens = [str(token) for token in self.tokenizer(self.text)]
+
+        char_to_token_map = []
+        token_to_char_map = []
+
+        j = 0
+        k = 0
+
+        for i in range(len(self.tokens)):
+            # For BioBERT, a split within a word is denoted by ##
+            if self.is_bert_tokenizer and self.tokens[i].startswith("##"):
+                k += 2
+
+            # Characters that are discarded from tokenization
+            while self.text[j].lower() != self.tokens[i][k].lower():
+                char_to_token_map.append(char_to_token_map[-1])
+                j += 1
+
+            # For SciSpacy, if there are multiple spaces, it removes
+            # one and keeps the rest
+            if self.text[j] == ' ' and j + 1 < len(self.text) \
+                    and self.text[j + 1] == ' ':
+                char_to_token_map.append(char_to_token_map[-1])
+                j += 1
+
+            token_start_idx = j
+            # Go over each letter in token and original text
+            while k < len(self.tokens[i]):
+                if self.text[j].lower() == self.tokens[i][k].lower():
+                    char_to_token_map.append(i)
+                    j += 1
+                    k += 1
+                else:
+                    msg = f"Error computing token to char map. ID: {self.record_id}"
+                    raise Exception(msg)
+
+            token_end_idx = j
+            token_to_char_map.append((token_start_idx, token_end_idx))
+            k = 0
+
+        # Characters at the end which are discarded by tokenizer
+        while j < len(self.text):
+            char_to_token_map.append(char_to_token_map[-1])
+            j += 1
+
+        assert len(char_to_token_map) == len(self.text)
+        assert len(token_to_char_map) == len(self.tokens)
+
+        self.char_to_token_map = char_to_token_map
+        self.token_to_char_map = token_to_char_map
+
+    def get_tokens(self) -> List[str]:
+        """
+        Returns the tokens.
+
+        Returns
+        -------
+        List[str]
+            List of tokens.
+        """
+        if self.tokenizer is None:
+            raise AttributeError("Tokenizer not set.")
+
+        return self.tokens
+
+    def set_tokenizer(self, tokenizer: Callable[[str], List[str]]) \
+            -> None:
+        """
+        Set the tokenizer for the object.
+
+        Parameters
+        ----------
+        tokenizer : Callable[[str], List[str]]
+            The tokenizer function to use.
+        """
+        self.tokenizer = tokenizer
+        if tokenizer is not None:
+            self._compute_tokens()
+
+    def get_token_idx(self, char_idx: int) -> int:
+        """
+        Returns the token index from character index.
+
+        Parameters
+        ----------
+        char_idx : int
+            Character index.
+
+        Returns
+        -------
+        int
+            Token index.
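+
+        Example (continuing the "a bc" example from
+        _compute_tokens): get_token_idx(3) returns 1.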
+        """
+        if self.tokenizer is None:
+            raise AttributeError("Tokenizer not set.")
+
+        token_idx = self.char_to_token_map[char_idx]
+
+        return token_idx
+
+    def get_char_idx(self, token_idx: int) -> Tuple[int, int]:
+        """
+        Returns the character range (start, end) of the specified
+        token. The end index is exclusive.
+
+        Parameters
+        ----------
+        token_idx : int
+            Token index.
+
+        Returns
+        -------
+        Tuple[int, int]
+            Character start and (exclusive) end indices.
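+
+        Example (continuing the "a bc" example from
+        _compute_tokens): get_char_idx(1) returns (2, 4).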
+        """
+        if self.tokenizer is None:
+            raise AttributeError("Tokenizer not set.")
+
+        char_range = self.token_to_char_map[token_idx]
+
+        return char_range
+
+    def get_labels(self) -> List[str]:
+        """
+        Get token labels in IOB format.
+
+        Returns
+        -------
+        List[str]
+            Labels.
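+
+        Example (illustrative): for the text "Took 2 mg of Ativan"
+        with "2 mg" annotated as Strength and "Ativan" as Drug,
+        whitespace tokenization gives the labels
+        ['O', 'B-STR', 'I-STR', 'O', 'B-DRUG'].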
+
+        """
+        if self.tokenizer is None:
+            raise AttributeError("No tokens found. Set tokenizer first.")
+
+        ent_label_map = {'Drug': 'DRUG', 'Strength': 'STR', 'Duration': 'DUR',
+                         'Route': 'ROU', 'Form': 'FOR', 'ADE': 'ADE', 'Dosage': 'DOS',
+                         'Reason': 'REA', 'Frequency': 'FRE'}
+
+        labels = ['O'] * len(self.tokens)
+
+        for ent in self.entities.values():
+            start_idx = self.get_token_idx(ent.range[0])
+            end_idx = self.get_token_idx(ent.range[1])
+
+            for idx in range(start_idx, end_idx + 1):
+                if idx == start_idx:
+                    labels[idx] = 'B-' + ent_label_map[ent.name]
+                else:
+                    labels[idx] = 'I-' + ent_label_map[ent.name]
+
+        return labels
+
+    def get_split_points(self, max_len: int = 510,
+                         new_line_ind: Optional[List[str]] = None,
+                         sent_end_ind: Optional[List[str]] = None
+                         ) -> List[int]:
+        """
+        Get the splitting points for tokens.
+
+        > It includes as many complete paragraphs as it can within
+        the max_len token limit. (The default of 510 leaves room
+        for the 2 special tokens BERT adds.)
+
+        > If it can't fit even a single complete paragraph, it
+        splits on the last new line that starts a new sentence.
+
+        > If it can't find that either, it splits at exactly the
+        max_len token limit.
+
+        Parameters
+        ----------
+        max_len : int, optional
+            Maximum number of tokens in one example. The default is
+            510 for BERT.
+
+        new_line_ind : List[str], optional
+            New line indicators: characters that, at the start of
+            the next line, mark a valid line break (digits also
+            count but are handled separately).
+            The default is ['[', '#', '-', '>', ' '].
+
+        sent_end_ind : List[str], optional
+            Sentence end indicators. The default is ['.', '?', '!'].
+
+        Returns
+        -------
+        List[int]
+            Splitting indices, including the first and last index.
+            Add 1 to the end indices when accessing with list
+            slicing.
+
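+        Example (illustrative): for a record with 246 tokens, a
+        return value of [0, 120, 246] corresponds to the chunks
+        tokens[0:121] and tokens[121:247] (i.e. tokens[121:]).
+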
+        """
+        if new_line_ind is None:
+            new_line_ind = ['[', '#', '-', '>', ' ']
+
+        if sent_end_ind is None:
+            sent_end_ind = ['.', '?', '!']
+
+        split_idx = [0]
+        last_par_end_idx = 0
+        last_line_end_idx = 0
+
+        for i in range(len(self.text)):
+            curr_counter = self.get_token_idx(i) - split_idx[-1]
+
+            if curr_counter >= max_len:
+                # If not even a single paragraph has ended
+                if last_par_end_idx == 0 and last_line_end_idx != 0:
+                    split_idx.append(last_line_end_idx)
+
+                elif last_par_end_idx != 0:
+                    split_idx.append(last_par_end_idx)
+
+                else:
+                    split_idx.append(self.get_token_idx(i))
+
+                last_par_end_idx = 0
+                last_line_end_idx = 0
+
+            if i < len(self.text) - 2 and self.text[i] == '\n':
+                if self.text[i + 1] == '\n':
+                    last_par_end_idx = self.get_token_idx(i - 1)
+
+                if self.text[i + 1] == '.' or self.text[i + 1] == '*':
+                    last_par_end_idx = self.get_token_idx(i + 1)
+
+                if self.text[i + 1] in new_line_ind or \
+                        self.text[i + 1].isdigit() or \
+                        self.text[i - 1] in sent_end_ind:
+                    last_line_end_idx = self.get_token_idx(i)
+
+        split_idx.append(len(self.tokens))
+        self.split_idx = split_idx
+
+        return self.split_idx
+
+    def get_annotations(self) \
+            -> Dict[str, Union[Dict[str, Entity], Dict[str, Relation]]]:
+        """
+        Get entities and relations in a dictionary.
+        Entities are referenced with the key 'entities'
+        and relations with the key 'relations'.
+
+        Returns
+        -------
+        Dict[str, Union[Dict[str, Entity], Dict[str, Relation]]]
+            Entities and relations.
+        """
+        if self.entities is None or self.relations is None:
+            raise AttributeError("Annotations not available")
+
+        return {'entities': self.entities, 'relations': self.relations}
+
+    def get_entities(self) -> Dict[str, Entity]:
+        """
+        Get the entities.
+
+        Returns
+        -------
+        Dict[str, Entity]
+            Entity ID: Entity object.
+        """
+        if self.entities is None:
+            raise AttributeError("Entities not set")
+
+        return self.entities
+
+    def get_relations(self) -> Dict[str, Relation]:
+        """
+        Get the entity relations.
+
+        Returns
+        -------
+        Dict[str, Relation]
+            Relation ID: Relation Object.
+        """
+        if self.relations is None:
+            raise AttributeError("Relations not set")
+
+        return self.relations
+
+    def _compute_elmo_embeddings(self) -> None:
+        """
+        Computes the Elmo embeddings for each token in EHR text data.
+        """
+        # noinspection PyUnresolvedReferences
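+        # embed_sentence returns per-layer outputs; keep only the
+        # final layer's vectors (hence the [-1] index)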
+        elmo_embeddings = self.elmo.embed_sentence(self.tokens)[-1]
+        self.elmo_embeddings = elmo_embeddings
+
+    def set_elmo_embedder(self, elmo: Callable[[str], numpy.ndarray]) -> None:
+        """
+        Set Elmo embedder for object.
+
+        Parameters
+        ----------
+        elmo :
+            The Elmo embedder to use. Must expose an
+            embed_sentence method.
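+
+        Example (illustrative; assumes allennlp's ElmoEmbedder,
+        which provides an embed_sentence method):
+
+        >>> from allennlp.commands.elmo import ElmoEmbedder
+        >>> record.set_elmo_embedder(ElmoEmbedder())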
+        """
+        self.elmo = elmo
+        if elmo is not None:
+            self._compute_elmo_embeddings()
+
+    def get_elmo_embeddings(self) -> numpy.ndarray:
+        """
+        Get the Elmo embeddings.
+
+        Returns
+        -------
+        numpy.ndarray
+            Elmo embeddings for each token.
+
+        """
+        if self.elmo_embeddings is None:
+            raise AttributeError("Elmo embeddings not set")
+
+        return self.elmo_embeddings