Switch to side-by-side view

--- /dev/null
+++ b/rule_based_ner/dict_ner.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import List
+from ehr import HealthRecord
+from annotations import Entity
+from collections import defaultdict
+import re
+
+
+class DictNER:
+    '''
+    A dictionary based NER model.
+
+    fit() builds one regular expression per entity type from the
+    annotated training text; predict() runs them over new text.
+    'Strength' uses a fixed pattern instead of a dictionary.
+    '''
+    # Strength is just a number followed by mg or mcg
+    _STRENGTH_RE = r'\d+[ ]*(?:mg|mcg)'
+    # Can never match; stands in for an empty dictionary
+    _NO_MATCH_RE = r'(?!)'
+
+    def __init__(self):
+        # Entity type name -> regular expression string
+        self.ner_re: dict = {}
+
+    def _get_clean_re(self, entity_list: List[str]) -> str:
+        '''
+        Generates a regular expression from a list of entities
+
+        Parameters
+        ----------
+        entity_list : List[str]
+            List of entity text. It is not modified.
+
+        Returns
+        -------
+        entity_re : str
+            Regular expression matching any entity as a whitespace
+            delimited token (never matches for an empty list).
+
+        '''
+        # re.escape handles every metacharacter ('.', '|', '\', ...)
+        # so entity text is always matched literally
+        escaped = [re.escape(text) for text in entity_list if text]
+        if not escaped:
+            return self._NO_MATCH_RE
+
+        # Lookarounds demand whitespace (or the text edge) on both
+        # sides without consuming it: the match spans exactly the
+        # entity and adjacent entities can share one separator
+        return r'(?<!\S)(?:' + '|'.join(escaped) + r')(?!\S)'
+
+    def fit(self, train_data: List[HealthRecord]) -> DictNER:
+        '''
+        Generates a dictionary for the model
+
+        Parameters
+        ----------
+        train_data : List[HealthRecord]
+            Records to generate the dictionary from.
+
+        Returns
+        -------
+        DictNER
+            Self object.
+
+        '''
+        # Inner dicts act as ordered sets (cheap duplicate checks)
+        vocab = defaultdict(dict)
+
+        for data in train_data:
+            for ent in data.entities.values():
+                # Strength has a specific RE; ignore text with length 1
+                if ent.name != 'Strength' and len(ent.ann_text) > 1:
+                    vocab[ent.name][ent.ann_text.lower()] = None
+
+        # Types with no usable text get no entry at all
+        ner_re = {name: self._get_clean_re(list(texts))
+                  for name, texts in vocab.items()}
+        ner_re['Strength'] = self._STRENGTH_RE
+        self.ner_re = ner_re
+        return self
+
+    def _find_spans(self, text: str) -> List[tuple]:
+        '''Returns (start, exclusive end, entity type) per match'''
+        return [(m.start(), m.end(), ent_name)
+                for ent_name, ent_re in self.ner_re.items()
+                for m in re.finditer(ent_re, text, re.IGNORECASE)]
+
+    def predict(self, test_data: List[HealthRecord])\
+            -> List[List[Entity]]:
+        '''
+        Returns character ranges for all predicted entities
+
+        Parameters
+        ----------
+        test_data : List[HealthRecord]
+            Text to predict the entities.
+
+        Returns
+        -------
+        List[List[Entity]]
+            Predictions for each example. Each prediction list
+            contains several Entity objects.
+
+        '''
+        predictions = []
+        for data in test_data:
+            entities = []
+            # Ids restart at T1 for every record
+            for j, span in enumerate(self._find_spans(data.text), 1):
+                ent = Entity(entity_id="T" + str(j))
+                ent.set_range([span[0], span[1]])
+                ent.set_entity_type(span[2])
+                entities.append(ent)
+            predictions.append(entities)
+
+        return predictions