# Dictionary-based NER model (rule_based_ner/dict_ner.py).
from __future__ import annotations

import re
from collections import defaultdict
from typing import List

from ehr import HealthRecord
from annotations import Entity


class DictNER:
    '''
    A dictionary based NER model.

    ``fit`` collects the annotated entity texts of the training records
    into one regular expression per entity type; ``predict`` runs those
    expressions (case-insensitively) over new records and wraps every
    match in an ``Entity``.

    Attributes
    ----------
    ner_re : dict
        Maps an entity type name to the regular expression (str) used to
        find entities of that type. Empty until ``fit`` is called.
    '''

    # Dosage is just a number followed by mg or mcg.
    _STRENGTH_RE = r'\d+[ ]*(?:mg|mcg)'

    # Zero-width "separate word" guards: the entry must be preceded and
    # followed by a space / new line / tab, or by the edge of the text.
    # Being zero-width, they keep the separators out of the match span, so
    # the span is exactly the entity and adjacent entities can both match.
    _WORD_START = r'(?<![^\n \t])'
    _WORD_END = r'(?![^\n \t])'

    # Pattern that can never match; used when a dictionary is empty.
    _NEVER_MATCH = r'(?!)'

    def __init__(self):
        self.ner_re: dict = {}

    def _get_clean_re(self, entity_list: List[str]) -> str:
        '''
        Generates a regular expression from a list of entities

        The input list is not modified. Every entry is escaped so that it
        is matched literally, and the alternation is wrapped in word
        guards so an entry only matches as a separate word.

        Parameters
        ----------
        entity_list : List[str]
            List of entity text.

        Returns
        -------
        entity_re : str
            Regular expression. If ``entity_list`` has no non-empty
            entries, a pattern that never matches is returned.

        '''
        # re.escape covers every regex metacharacter ('.', '|', '\\', ...),
        # so entity text is never interpreted as regex syntax.
        escaped = [re.escape(text) for text in entity_list if text]

        if not escaped:
            return self._NEVER_MATCH

        return (self._WORD_START
                + '(?:' + '|'.join(escaped) + ')'
                + self._WORD_END)

    def fit(self, train_data: List[HealthRecord]) -> DictNER:
        '''
        Generates a dictionary for the model

        Collects the lower-cased annotation text of every entity in the
        training records, grouped by entity type, and stores one regular
        expression per type in ``self.ner_re``. 'Strength' entities are
        not collected; they get a fixed number + unit expression instead.

        Parameters
        ----------
        train_data : List[HealthRecord]
            Records to generate the dictionary from.

        Returns
        -------
        DictNER
            Self object.

        '''
        # name -> {lower-cased text: None}; a dict is used as an
        # insertion-ordered set so de-duplication is O(1) per entity.
        vocab = defaultdict(dict)

        for data in train_data:
            for ent in data.entities.values():
                # We have a specific RE for Dosage
                if ent.name == 'Strength':
                    continue

                # Ignore text with length 1. The type is only registered
                # once it has at least one usable text, so no type ends
                # up with an empty dictionary.
                if len(ent.ann_text) > 1:
                    vocab[ent.name][ent.ann_text.lower()] = None

        ner_re = {name: self._get_clean_re(list(texts))
                  for name, texts in vocab.items()}
        ner_re['Strength'] = self._STRENGTH_RE

        self.ner_re = ner_re
        return self

    def predict(self, test_data: List[HealthRecord])\
            -> List[List[Entity]]:
        '''
        Returns character ranges for all predicted entities

        Parameters
        ----------
        test_data : List[HealthRecord]
            Text to predict the entities.

        Returns
        -------
        List[List[Entity]]
            Predictions for each example. Each prediction list
            contains several Entity objects, with ids "T1", "T2", ...
            numbered per example in the order they are found.

        '''
        predictions = []
        for data in test_data:
            entities = []
            for ent_name, ent_re in self.ner_re.items():
                for match in re.finditer(ent_re, data.text, re.IGNORECASE):
                    # The match span is exactly the entity text (the word
                    # guards are zero-width), so no offset correction is
                    # applied; this also keeps 'Strength' ranges intact.
                    ent = Entity(entity_id="T" + str(len(entities) + 1))
                    ent.set_range([match.start(), match.end()])
                    ent.set_entity_type(ent_name)
                    entities.append(ent)

            predictions.append(entities)

        return predictions