from typing import List, Tuple, Callable, Dict, Union, Iterable
from annotations import Entity, Relation
from ehr import HealthRecord
import os
import sys
from pickle import dump, load
from IPython.core.display import display, HTML
import json
from collections import defaultdict
import pandas as pd
import networkx as nx
import math
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import matplotlib
# Template for a plain highlighted span (used for the legend and notebook output)
TPL_HTML = (
    '<span style = "background-color: {color}; border-radius: 5px;">'
    ' {content} </span>'
)

# Template for a highlighted span that also carries its relation-group CSS
# classes and a nested span holding the entity type
TPL_HTML_HOVER = (
    '<span style = "background-color: {color}; border-radius: 5px;" class="{grp}">'
    ' {content} '
    '<span style = "background: {color}">{ent_type}</span>'
    '</span>'
)

# Background colour used for each entity type
COLORS = {
    "Drug": "#aa9cfc",
    "Strength": "#ff9561",
    "Form": "#7aecec",
    "Frequency": "#9cc9cc",
    "Route": "#ffeb80",
    "Dosage": "#bfe1d9",
    "Reason": "#e4e7d2",
    "ADE": "#ff8197",
    "Duration": "#97c4f5",
}
def add_ent_group(entities: Union[Dict[str, Entity], List[Entity]],
                  relations: Union[Dict[str, Relation], List[Relation]]) -> List[Entity]:
    """
    Adds relation group to Entity objects.

    For every relation, the token ``"group-<relation id> "`` is appended
    to the ``relation_group`` attribute of both of its argument entities.
    The entities are modified in place.

    Parameters
    ----------
    entities : Union[Dict[str, Entity], List[Entity]]
        Entities, either as a list or as a dictionary keyed by
        annotation ID.

    relations : Union[Dict[str, Relation], List[Relation]]
        Relations, either as a list or as a dictionary whose values
        are the Relation objects.

    Returns
    -------
    List[Entity]
        List of Entities with group information added.

    Raises
    ------
    KeyError
        If a relation refers to an entity ID that is not in `entities`.
    """
    # Index the entities by annotation ID so relation arguments can be looked up
    if not isinstance(entities, dict):
        entities = {ent.ann_id: ent for ent in entities}

    # Iterating a dict yields its keys, so unwrap to the Relation objects
    if isinstance(relations, dict):
        relations = relations.values()

    # Add group information
    for rel in relations:
        group = "group-" + rel.ann_id + " "
        entities[rel.arg1.ann_id].relation_group += group
        entities[rel.arg2.ann_id].relation_group += group

    return list(entities.values())
# noinspection PyTypeChecker
def display_ehr(text: str,
                entities: Union[Dict[str, Entity], List[Entity]],
                relations: Union[Dict[str, Relation], List[Relation]] = None,
                return_html: bool = False) -> Union[None, str]:
    """
    Highlights EHR records with colors and displays
    them as HTML. Ideal for working with Jupyter Notebooks

    Parameters
    ----------
    text : str
        EHR record to render

    entities : Union[Dict[str, Entity], List[Entity]]
        Entity objects to highlight. The caller's container is not
        reordered.

    relations : Union[Dict[str, Relation], List[Relation]]
        Relations. If provided, they are passed to `add_ent_group`
        together with the entities before rendering.
        The default is None.

    return_html : bool
        Indicator for returning HTML or printing the tagged EHR.
        The default is False.

    Returns
    -------
    Union[None, str]
        If return_html is true, returns html strings
        otherwise displays HTML.

    Raises
    ------
    KeyError
        If an entity's name has no entry in COLORS.
    """
    from html import escape

    if relations is not None:
        entities = add_ent_group(entities, relations)

    if isinstance(entities, dict):
        entities = list(entities.values())

    # Sort by starting offset into a new list so the caller's list
    # keeps its original order
    ordered = sorted(entities, key=lambda ent: ent.range[0])

    # Pieces of the final markup, joined once at the end
    parts = []

    # Display legend
    if not return_html:
        for ent_type, col in COLORS.items():
            parts.append(TPL_HTML.format(content=ent_type, color=col))
            parts.append(" " * 5)
        parts.append('\n')
        parts.append('--' * 50)
        parts.append("\n\n")

    # Replace each character range with HTML span template
    cursor = 0
    for ent in ordered:
        start, end = ent.range[0], ent.range[1]

        # Skip entities that begin inside an already highlighted span
        if cursor > start:
            continue

        # Escape raw record text so characters such as '<' or '&'
        # cannot break the generated markup
        parts.append(escape(text[cursor:start], quote=False))
        content = escape(text[start:end], quote=False)

        if return_html:
            parts.append(TPL_HTML_HOVER.format(
                content=content,
                color=COLORS[ent.name],
                grp=ent.relation_group,
                ent_type=ent.name))
        else:
            parts.append(TPL_HTML.format(
                content=content,
                color=COLORS[ent.name]))

        cursor = end

    parts.append(escape(text[cursor:], quote=False))

    # Line breaks in the record become HTML line breaks
    render_text = "".join(parts).replace("\n", "<br>")

    if return_html:
        return render_text

    display(HTML(render_text))
    return None
def display_knowledge_graph(long_relation_df: pd.DataFrame, num_col: int = 2,
                            height: int = 8, width: int = 8,
                            return_html: bool = False) -> Union[None, str]:
    """
    Draws one knowledge graph per drug, arranged in a grid of subplots,
    and either shows the figure or returns it as an HTML img tag.

    Parameters
    ----------
    long_relation_df: pd.DataFrame
        Relation dataframe in long format. Should have columns named:
        ['drug_id', 'drug', 'arg', 'edge']

    num_col: int
        Number of columns in the grid. Number of rows are automatically
        calculated based on this. The default is 2.

    height: int
        The height of a single graph in inches. The default is 8.

    width: int
        The width of a single graph in inches. The default is 8.

    return_html: bool
        Indicator for returning the HTML img tag or displaying the plot.
        The default is False.

    Returns
    -------
    Union[None, str]
        If return_html is true, returns html string
        otherwise displays the plot. None is also returned when the
        dataframe contains no drugs.
    """
    if return_html:
        # Switch matplotlib to the 'Agg' backend before drawing
        matplotlib.use('Agg')

    drug_ids = sorted(pd.unique(long_relation_df['drug_id']))
    num_row = math.ceil(len(drug_ids) / num_col)

    # Nothing to draw
    if num_row == 0:
        return None

    fig, _ = plt.subplots(num_row, num_col,
                          figsize=(num_col * width, height * num_row))

    for idx, drug_id in enumerate(drug_ids):
        sub_rel = long_relation_df[long_relation_df["drug_id"] == drug_id]
        labels = sub_rel.set_index(['drug', 'arg'])['edge'].to_dict()

        plt.subplot(num_row, num_col, idx + 1)

        # Knowledge graph for a single drug
        graph = nx.from_pandas_edgelist(sub_rel, "drug", "arg", edge_attr=True,
                                        create_using=nx.MultiDiGraph())

        # Assumes the drug is the first node in the graph, so it gets
        # the distinct colour
        color_map = ['#aa9cfc'] + ['skyblue'] * (len(graph.nodes) - 1)
        pos = nx.spring_layout(graph)

        # Draw the graph
        nx.draw(graph, with_labels=True, font_size=12, pos=pos,
                node_color=color_map, node_size=2000)

        # Draw edge labels
        nx.draw_networkx_edge_labels(graph, edge_labels=labels,
                                     pos=pos, font_color='red')

    # Remove axis for empty plots, if any
    for idx in range(len(drug_ids), num_row * num_col):
        plt.subplot(num_row, num_col, idx + 1)
        plt.axis('off')

    if not return_html:
        plt.show()
        return None

    # Create an encoding for the image. The figure is closed afterwards
    # so that repeated calls do not accumulate open figures.
    try:
        fig.tight_layout()
        with BytesIO() as tmp_file:
            fig.savefig(tmp_file, format="png")
            encoded = base64.b64encode(tmp_file.getvalue()).decode('utf-8')
    finally:
        plt.close(fig)

    return '<img id="knowledge-graph" src=\'data:image/png;base64,{}\'>'.format(encoded)
def _list_record_ids(directory: str) -> List[str]:
    """Returns the sorted, unique file names (without extension) in a directory, skipping dot-files."""
    return sorted({'.'.join(fname.split('.')[:-1])
                   for fname in os.listdir(directory)
                   if not fname.startswith('.')})


def _load_records(directory: str, record_ids: List[str],
                  tokenizer: Callable[[str], List[str]],
                  is_bert_tokenizer: bool,
                  show_progress: bool) -> List[HealthRecord]:
    """Builds one HealthRecord per ID from its '.txt' and '.ann' files in a directory."""
    records = []
    for idx, fid in enumerate(record_ids):
        records.append(HealthRecord(fid,
                                    text_path=os.path.join(directory, fid + '.txt'),
                                    ann_path=os.path.join(directory, fid + '.ann'),
                                    tokenizer=tokenizer,
                                    is_bert_tokenizer=is_bert_tokenizer))
        if show_progress:
            draw_progress_bar(idx + 1, len(record_ids))
    return records


def read_data(data_dir: str = 'data/',
              tokenizer: Callable[[str], List[str]] = None,
              is_bert_tokenizer: bool = True,
              verbose: int = 0) -> Tuple[List[HealthRecord], List[HealthRecord]]:
    """
    Reads train and test data

    Records are loaded in sorted ID order, so the result is the same
    on every run.

    Parameters
    ----------
    data_dir : str, optional
        Directory where the data is located.
        It should have directories named 'train' and 'test'
        The default is 'data/'.

    tokenizer : Callable[[str], List[str]], optional
        The tokenizer function to use. The default is None.

    is_bert_tokenizer : bool
        If the tokenizer is a BERT-based WordPiece tokenizer

    verbose : int, optional
        1 to print reading progress, 0 otherwise. The default is 0.

    Returns
    -------
    Tuple[List[HealthRecord], List[HealthRecord]]
        Train data, Test data.
    """
    train_path = os.path.join(data_dir, "train")
    test_path = os.path.join(data_dir, "test")

    # Get all IDs for train and test data up front, so a missing
    # directory is reported before any record is read
    train_ids = _list_record_ids(train_path)
    test_ids = _list_record_ids(test_path)

    show_progress = verbose == 1

    if show_progress:
        print("Train data:")
    train_data = _load_records(train_path, train_ids, tokenizer,
                               is_bert_tokenizer, show_progress)

    if show_progress:
        print('\n\nTest Data:')
    test_data = _load_records(test_path, test_ids, tokenizer,
                              is_bert_tokenizer, show_progress)

    return train_data, test_data
def read_ade_data(ade_data_dir: str = 'ade_data/',
                  verbose: int = 0) -> List[Dict]:
    """
    Reads the ADE data

    Every non-hidden file name in the directory is taken as an ID,
    '<ID>.json' is loaded for each, and the combined contents are
    passed through `process_ade_files`.

    Parameters
    ----------
    ade_data_dir : str, optional
        Directory where the ADE data is located. A trailing path
        separator is not required. The default is 'ade_data/'.

    verbose : int, optional
        1 to print a message once reading is done, 0 otherwise.
        The default is 0.

    Returns
    -------
    List[Dict]
        ADE data
    """
    # Get all the IDs of ADE data
    ade_file_ids = sorted({'.'.join(fname.split('.')[:-1])
                           for fname in os.listdir(ade_data_dir)
                           if not fname.startswith('.')})

    # Load ADE data
    ade_data = []
    for fid in ade_file_ids:
        # os.path.join works whether or not the directory ends with a separator
        json_path = os.path.join(ade_data_dir, fid + '.json')
        with open(json_path, encoding='utf-8') as f:
            ade_data.extend(json.load(f))

    ade_data = process_ade_files(ade_data)

    if verbose == 1:
        print("\n\nADE data: Done")

    return ade_data
def process_ade_files(ade_data: List[dict]) -> List[dict]:
    """
    Extracts tokens and creates Entity and Relation objects
    from raw json data. The input records are not modified.

    Parameters
    ----------
    ade_data : List[dict]
        Raw json data. Each record is read through its 'tokens',
        'entities' (with 'type', 'start', 'end') and 'relations'
        (with 'head', 'tail') keys.

    Returns
    -------
    List[dict]
        Tokens, entities and relations. One dictionary per record with
        the keys 'tokens', 'entities' and 'relations'; entities and
        relations are dictionaries keyed by their generated IDs.
    """
    ade_records = []

    for ade in ade_data:
        # Tokens
        tokens = ade['tokens']

        # Entities, numbered T1, T2, ... in the order they appear
        entities = {}
        for e_num, ent in enumerate(ade['entities'], start=1):
            ent_id = 'T%d' % e_num

            # Rename the label in a local variable so that the
            # caller's record is left untouched
            ent_type = 'ADE' if ent['type'] == 'Adverse-Effect' else ent['type']

            ent_obj = Entity(entity_id=ent_id, entity_type=ent_type)

            # 'end' is exclusive in the raw data; the stored range is inclusive
            ent_obj.set_range([int(ent['start']), int(ent['end'] - 1)])

            # Each token is followed by a single space (including the
            # last one), as callers have always received it
            ent_obj.set_text(''.join(token + ' '
                                     for token in tokens[ent['start']:ent['end']]))

            entities[ent_id] = ent_obj

        # Relations, numbered R1, R2, ... in the order they appear
        relations = {}
        for r_num, relation in enumerate(ade['relations'], start=1):
            rel_id = 'R%d' % r_num

            # 'head' and 'tail' are zero-based positions in the entity list
            head_id = "T" + str(relation['head'] + 1)
            tail_id = "T" + str(relation['tail'] + 1)

            # Relations pointing at a missing entity are dropped
            # (their number is still used up)
            if head_id not in entities or tail_id not in entities:
                continue

            relations[rel_id] = Relation(relation_id=rel_id,
                                         relation_type='ADE-Drug',
                                         arg1=entities[head_id],
                                         arg2=entities[tail_id])

        ade_records.append({"tokens": tokens, "entities": entities, "relations": relations})

    return ade_records
def map_entities(entities: Union[Dict[str, Entity], List[Entity]],
                 actual_relations: Union[Dict[str, Relation], List[Relation]] = None) \
        -> Union[List[Tuple[Relation, None]], List[Tuple[Relation, int]]]:
    """
    Pairs every drug entity with every non-drug entity as a candidate
    relation.

    Parameters
    ----------
    entities : Union[Dict[str, Entity], List[Entity]]
        Entities to pair up. For a dictionary, its values are used.

    actual_relations : Union[Dict[str, Relation], List[Relation]], optional
        Actual relations (for training data). For a dictionary,
        its values are used. The default is None.

    Returns
    -------
    Union[List[Tuple[Relation, None]], List[Tuple[Relation, int]]]
        One tuple per candidate relation. The second item is None when
        no actual relations are given; otherwise it is 1 if the
        candidate equals one of the actual relations of the same type
        and 0 if not.
    """
    if isinstance(entities, dict):
        entities = list(entities.values())

    if actual_relations and isinstance(actual_relations, dict):
        actual_relations = list(actual_relations.values())

    # Separate the drugs from everything else
    drugs, others = [], []
    for entity in entities:
        bucket = drugs if entity.name.lower() == "drug" else others
        bucket.append(entity)

    # Candidate relations are numbered R1, R2, ... drug by drug
    pairs = ((drug, other) for drug in drugs for other in others)
    candidates = [
        Relation(relation_id="R%d" % num,
                 relation_type=other.name + "-Drug",
                 arg1=drug, arg2=other)
        for num, (drug, other) in enumerate(pairs, start=1)
    ]

    if actual_relations is None:
        return [(candidate, None) for candidate in candidates]

    # Group the actual relations by type, so that each candidate is
    # only compared against relations of its own type
    actual_by_type = defaultdict(list)
    for actual in actual_relations:
        actual_by_type[actual.name].append(actual)

    # 1 when the candidate matches an actual relation, 0 otherwise
    flags = [
        int(any(candidate == actual for actual in actual_by_type[candidate.name]))
        for candidate in candidates
    ]

    return list(zip(candidates, flags))
def get_long_relation_table(relations: Iterable[Relation]) -> pd.DataFrame:
    """
    Builds a long-format table of relations, one row per relation,
    with the columns ['drug_id', 'drug', 'arg', 'edge'].

    Parameters
    ----------
    relations : Iterable[Relation]
        Relations to tabulate.

    Returns
    -------
    pd.DataFrame
        'drug_id' and 'drug' are the ID and text of the drug argument,
        'arg' is the text of the other argument and 'edge' is the part
        of the relation name before the first '-'.
    """
    drug_ids, drug_texts, arg_texts, edges = [], [], [], []

    for rel in relations:
        # If the first argument is not named "Drug", the second one
        # is treated as the drug
        if rel.arg1.name == "Drug":
            drug_ent, other_ent = rel.arg1, rel.arg2
        else:
            drug_ent, other_ent = rel.arg2, rel.arg1

        drug_ids.append(drug_ent.ann_id)
        drug_texts.append(drug_ent.ann_text)
        arg_texts.append(other_ent.ann_text)
        edges.append(rel.name.split('-')[0])

    return pd.DataFrame({'drug_id': drug_ids, 'drug': drug_texts,
                         'arg': arg_texts, 'edge': edges})
def get_relation_table(relations: Union[pd.DataFrame, Iterable[Relation]],
                       is_long_df: bool = True) -> str:
    """
    Returns the relations in a wide table format, rendered as HTML.

    Parameters
    ----------
    relations : Union[pd.DataFrame, Iterable[Relation]]
        Either a list of relations, or relations table in long format.

    is_long_df : bool
        Indicator for relations parameter. True indicates the input is
        a long dataframe. False indicates it is a list of relations.
        The default is True.

    Returns
    -------
    str
        HTML blob of all the relations in a tabular format.
    """
    # Convert first: a list of relations has no drop_duplicates()
    if not is_long_df:
        relations = get_long_relation_table(relations)

    relations = relations.drop_duplicates().rename(
        columns={"drug_id": "Drug ID", "drug": "Drug",
                 "edge": "Entity Type", "arg": "Entity Text"})

    group_cols = ["Drug ID", "Drug", "Entity Type"]

    # One row per (drug, entity type), holding all of its entity texts
    relation_df = (
        relations
        .groupby(group_cols)["Entity Text"]
        .apply(list)
        .reset_index(name="Entity Text")
        .set_index(group_cols)
    )

    # Put each entity text on its own line within the cell
    relation_df["Entity Text"] = relation_df["Entity Text"].apply("\n".join)

    empty_header = " <tr style=\"text-align: right;\">\n <th></th>\n <th></th>\n <th></th>\n <th>Entity Text</th>\n </tr>\n"
    empty_colname = "<th></th>"

    # Turn the backslash-n sequences in the output into HTML line
    # breaks, remove the header row given by `empty_header` and fill
    # the remaining empty header cells with the column name
    relation_html = (
        relation_df
        .to_html(classes=['table'], border=0)
        .replace("\\n", "<br>")
        .replace(empty_header, "")
        .replace(empty_colname, "<th>Entity Text</th>")
    )

    return relation_html
def draw_progress_bar(current, total, string='', bar_len=20):
    """
    Draws a progress bar on stdout, like
    Progress: [==========>         ] 1/2

    The line starts with a carriage return and has no trailing
    newline, so successive calls write over the same line.

    Parameters
    ------------
    current: int/float
        Current progress

    total: int/float
        The total from which the current progress is made. A total
        of 0 is drawn as a completed bar.

    string: str
        Additional details to write along with progress

    bar_len: int
        Length of progress bar
    """
    # Nothing to do counts as done; this also avoids a division by zero
    fraction = current / total if total else 1

    # Keep the bar inside its brackets when current is out of range
    fraction = min(max(fraction, 0), 1)

    # The arrow head is dropped once the bar is full
    arrow = "" if fraction == 1 else ">"
    bar = "=" * int(bar_len * fraction) + arrow

    # Carriage return, returns to the beginning of line to overwrite
    sys.stdout.write("\r")
    sys.stdout.write("Progress: [{:<{}}] {}/{}".format(bar, bar_len, current, total) + string)
    sys.stdout.flush()
def is_whitespace(char):
    """
    Tells whether a character counts as whitespace: a space, tab,
    carriage return, newline or U+202F.

    Parameters
    --------------
    char: str
        A single character string to check

    Returns
    --------------
    bool
        True if the character is one of the above, False otherwise.
    """
    # 0x202F is the code point of NARROW NO-BREAK SPACE
    return char in (" ", "\t", "\r", "\n") or ord(char) == 0x202F
def is_punct(char):
    """
    Tells whether a character is one of the punctuation marks
    '.', ',', '!', '?' or a backslash.

    Parameters
    --------------
    char: str
        A single character string to check

    Returns
    --------------
    bool
        True if the character is one of the above, False otherwise.
    """
    return char in (".", ",", "!", "?", "\\")
def save_pickle(file, variable):
    """
    Pickles a variable to a file and prints where it was saved.

    Parameters
    -----------
    file: str
        File name/path in which the variable is to be stored. ".pkl"
        is appended unless the text after the last '.' is "pkl".

    variable: object
        The variable to be stored in a file
    """
    extension = file.rpartition('.')[2]
    target = file if extension == "pkl" else file + ".pkl"

    with open(target, 'wb') as handle:
        dump(variable, handle)

    print("Variable successfully saved in " + target)
def open_pickle(file):
    """
    Loads and returns the variable stored in a pickle file.

    NOTE: unpickling can execute arbitrary code, so only open files
    from a trusted source.

    Parameters
    -----------
    file: str
        File name/path from which variable is to be loaded. ".pkl"
        is appended unless the text after the last '.' is "pkl".

    Returns
    -----------
    object
        The unpickled variable.
    """
    extension = file.rpartition('.')[2]
    source = file if extension == "pkl" else file + ".pkl"

    with open(source, 'rb') as handle:
        return load(handle)