a b/utils.py
1
from typing import List, Tuple, Callable, Dict, Union, Iterable
2
from annotations import Entity, Relation
3
from ehr import HealthRecord
4
5
import os
6
import sys
7
from pickle import dump, load
8
from IPython.core.display import display, HTML
9
import json
10
from collections import defaultdict
11
import pandas as pd
12
import networkx as nx
13
import math
14
import matplotlib.pyplot as plt
15
from io import BytesIO
16
import base64
17
import matplotlib
18
19
20
# Inline span that highlights a piece of text with a background colour.
TPL_HTML = (
    '<span style = "background-color: {color}; border-radius: 5px;">'
    '&nbsp;{content}&nbsp;</span>'
)

# Same highlight span, but carrying the relation-group CSS classes and a
# nested span that holds the entity type.
TPL_HTML_HOVER = (
    '<span style = "background-color: {color}; border-radius: 5px;" class="{grp}">'
    '&nbsp;{content}&nbsp;'
    '<span style = "background: {color}">{ent_type}</span></span>'
)

# Background colour (hex) used for each entity type. The insertion order is
# the order in which the legend is rendered.
COLORS = {
    "Drug": "#aa9cfc",
    "Strength": "#ff9561",
    "Form": "#7aecec",
    "Frequency": "#9cc9cc",
    "Route": "#ffeb80",
    "Dosage": "#bfe1d9",
    "Reason": "#e4e7d2",
    "ADE": "#ff8197",
    "Duration": "#97c4f5",
}
29
30
31
def add_ent_group(entities: Union[Dict[str, Entity], List[Entity]],
                  relations: Union[Dict[str, Relation], List[Relation]]) -> List[Entity]:
    """
    Adds relation group to Entity objects.

    For every relation, the string ``"group-<relation id> "`` is appended to
    the ``relation_group`` attribute of both of its arguments. The Entity
    objects are modified in place.

    Parameters
    ----------
    entities : Union[Dict[str, Entity], List[Entity]]
        Entities, either as a list or as a dictionary keyed by annotation ID.

    relations : Union[Dict[str, Relation], List[Relation]]
        Relations, either as a list or as a dictionary whose values are
        the Relation objects.

    Returns
    -------
    List[Entity]
        List of Entities with group information added.

    Raises
    ------
    KeyError
        If a relation refers to an entity whose annotation ID is not
        present in ``entities``.
    """
    # Index the entities by annotation ID so relation arguments can be looked up
    if not isinstance(entities, dict):
        entities = {ent.ann_id: ent for ent in entities}

    # Iterating a dict yields its keys; the Relation objects are the values
    if isinstance(relations, dict):
        relations = relations.values()

    # Add group information to both ends of every relation
    for rel in relations:
        group = "group-" + rel.ann_id + " "
        entities[rel.arg1.ann_id].relation_group += group
        entities[rel.arg2.ann_id].relation_group += group

    return list(entities.values())
63
64
65
# noinspection PyTypeChecker
66
def display_ehr(text: str,
                entities: Union[Dict[str, Entity], List[Entity]],
                relations: Union[Dict[str, Relation], List[Relation]] = None,
                return_html: bool = False) -> Union[None, str]:
    """
    Highlights EHR records with colors and displays
    them as HTML. Ideal for working with Jupyter Notebooks

    The record text is HTML-escaped before it is inserted into the markup,
    so characters such as ``<``, ``>`` and ``&`` are shown literally.
    An entity that starts before the end of the previously rendered
    entity (an overlap) is skipped.

    Parameters
    ----------
    text : str
        EHR record to render

    entities : Union[Dict[str, Entity], List[Entity]]
         Entity objects, as a list or as a dictionary of them.
         The passed container is not reordered.

    relations : Union[Dict[str, Relation], List[Relation]]
        Relations. If provided, the relation group of each entity
        is filled in (via add_ent_group) before rendering.

    return_html : bool
        Indicator for returning HTML or printing the tagged EHR.
        The default is False.

    Returns
    -------
    Union[None, str]
        If return_html is true, returns html strings
        otherwise displays HTML.

    """
    # Local import: keeps this function self-contained
    from html import escape

    if relations is not None:
        entities = add_ent_group(entities, relations)

    if isinstance(entities, dict):
        entities = entities.values()

    # Sort entity by starting range; sorted() leaves the caller's list untouched
    ordered = sorted(entities, key=lambda ent: ent.range[0])

    # Pieces of the final text to render, joined once at the end
    parts = []

    # Display legend
    if not return_html:
        for ent_type, col in COLORS.items():
            parts.append(TPL_HTML.format(content=ent_type, color=col))
            parts.append("&nbsp" * 5)

        parts.append('\n')
        parts.append('--' * 50)
        parts.append("\n\n")

    # Replace each character range with HTML span template
    start_idx = 0
    for ent in ordered:
        ent_start, ent_end = ent.range[0], ent.range[1]

        # Overlaps the previously rendered entity: skip it
        if start_idx > ent_start:
            continue

        # Plain text between the previous entity and this one
        parts.append(escape(text[start_idx:ent_start], quote=False))

        content = escape(text[ent_start:ent_end], quote=False)
        if return_html:
            parts.append(TPL_HTML_HOVER.format(
                content=content,
                color=COLORS[ent.name],
                grp=ent.relation_group,
                ent_type=ent.name))
        else:
            parts.append(TPL_HTML.format(
                content=content,
                color=COLORS[ent.name]))

        start_idx = ent_end

    # Remaining text after the last entity
    parts.append(escape(text[start_idx:], quote=False))

    render_text = "".join(parts).replace("\n", "<br>")

    if return_html:
        return render_text

    display(HTML(render_text))
    return None
146
147
148
def display_knowledge_graph(long_relation_df: pd.DataFrame, num_col: int = 2,
                            height: int = 8, width: int = 8,
                            return_html: bool = False) -> Union[None, str]:
    """
    Draws one knowledge graph per drug on a grid of subplots and either
    shows the plot or returns it as an HTML img tag.

    Parameters
    ----------
    long_relation_df: pd.DataFrame
        Relation dataframe in long format. Should have columns named:
        ['drug_id', 'drug', 'arg', 'edge']

    num_col: int
        Number of columns in the grid. Number of rows are automatically
        calculated based on this. The default is 2.

    height: int
        The height of a single graph in inches. The default is 8.

    width: int
        The width of a single graph in inches. The default is 8.

    return_html: bool
        Indicator for returning the HTML img tag or displaying the plot.
        The default is False.

    Returns
    -------
    Union[None, str]
        If return_html is true, returns html string
        otherwise displays the plot. Returns None when the dataframe
        contains no drug IDs.

    """
    if return_html:
        # Non-interactive backend: the figure is only rendered to a buffer
        matplotlib.use('Agg')

    drug_ids = sorted(pd.unique(long_relation_df['drug_id']))
    num_row = math.ceil(len(drug_ids) / num_col)

    if num_row == 0:
        return None

    fig, _ = plt.subplots(num_row, num_col, figsize=(num_col * width, height * num_row))

    for i, d in enumerate(drug_ids):
        sub_rel = long_relation_df[long_relation_df["drug_id"] == d]
        labels = sub_rel.set_index(['drug', 'arg'])['edge'].to_dict()

        plt.subplot(num_row, num_col, i + 1)

        # Knowledge graph for a single drug
        graph = nx.from_pandas_edgelist(sub_rel, "drug", "arg", edge_attr=True, create_using=nx.MultiDiGraph())

        # Assumes the drug is the first node in the graph, so it gets
        # the highlight colour and every other node is sky blue
        color_map = ['#aa9cfc'] + ['skyblue'] * (len(graph.nodes) - 1)

        pos = nx.spring_layout(graph)

        # Draw the graph
        nx.draw(graph, with_labels=True, font_size=12, pos=pos,
                node_color=color_map, node_size=2000)

        # Draw edge labels
        nx.draw_networkx_edge_labels(graph, edge_labels=labels,
                                     pos=pos, font_color='red')

    # Remove axis for empty plots, if any
    for j in range(len(drug_ids), num_row * num_col):
        plt.subplot(num_row, num_col, j + 1)
        plt.axis('off')

    if not return_html:
        plt.show()
        return None

    # Create an encoding for the image
    tmp_file = BytesIO()
    try:
        plt.tight_layout()
        plt.savefig(tmp_file, format="png")
    finally:
        # Release the figure; otherwise every call leaves one open
        plt.close(fig)

    encoded = base64.b64encode(tmp_file.getvalue()).decode('utf-8')
    img_tag = '<img id="knowledge-graph" src=\'data:image/png;base64,{}\'>'.format(encoded)

    return img_tag
237
238
239
def read_data(data_dir: str = 'data/',
              tokenizer: Callable[[str], List[str]] = None,
              is_bert_tokenizer: bool = True,
              verbose: int = 0) -> Tuple[List[HealthRecord], List[HealthRecord]]:
    """
    Reads train and test data

    Record IDs are taken from the file names (without their last
    extension) and processed in sorted order, so the order of the
    returned records is the same on every run.

    Parameters
    ----------
    data_dir : str, optional
        Directory where the data is located.
        It should have directories named 'train' and 'test'
        The default is 'data/'.

    tokenizer : Callable[[str], List[str]], optional
        The tokenizer function to use. The default is None.

    is_bert_tokenizer : bool
        If the tokenizer is a BERT-based WordPiece tokenizer

    verbose : int, optional
        1 to print reading progress, 0 otherwise. The default is 0.

    Returns
    -------
    Tuple[List[HealthRecord], List[HealthRecord]]
        Train data, Test data.

    """
    train_path = os.path.join(data_dir, "train")
    test_path = os.path.join(data_dir, "test")

    # Get all IDs for train and test data
    train_ids = _list_record_ids(train_path)
    test_ids = _list_record_ids(test_path)

    if verbose == 1:
        print("Train data:")

    train_data = _load_records(train_path, train_ids, tokenizer,
                               is_bert_tokenizer, verbose)

    if verbose == 1:
        print('\n\nTest Data:')

    test_data = _load_records(test_path, test_ids, tokenizer,
                              is_bert_tokenizer, verbose)

    return train_data, test_data


def _list_record_ids(dir_path: str) -> List[str]:
    """Returns the sorted, unique record IDs (file names minus last extension) in dir_path, ignoring dot-files."""
    return sorted({'.'.join(fname.split('.')[:-1])
                   for fname in os.listdir(dir_path)
                   if not fname.startswith('.')})


def _load_records(dir_path: str, file_ids: List[str],
                  tokenizer: Callable[[str], List[str]],
                  is_bert_tokenizer: bool, verbose: int) -> List[HealthRecord]:
    """Builds one HealthRecord per ID from its '.txt' and '.ann' files in dir_path."""
    records = []
    total = len(file_ids)
    for done, fid in enumerate(file_ids, start=1):
        records.append(HealthRecord(fid,
                                    text_path=os.path.join(dir_path, fid + '.txt'),
                                    ann_path=os.path.join(dir_path, fid + '.ann'),
                                    tokenizer=tokenizer,
                                    is_bert_tokenizer=is_bert_tokenizer))
        if verbose == 1:
            draw_progress_bar(done, total)
    return records
307
308
309
def read_ade_data(ade_data_dir: str = 'ade_data/',
                  verbose: int = 0) -> List[Dict]:
    """
    Reads ADE data

    Every non-hidden '.json' file in the directory is loaded (in sorted
    order of file name), the loaded lists are concatenated and the result
    is passed through process_ade_files.

    Parameters
    ----------

    ade_data_dir : str, optional
        Directory where the ADE data is located. A trailing path separator
        is not required. The default is 'ade_data/'.

    verbose : int, optional
        1 to print a message once the data is loaded, 0 otherwise.
        The default is 0.

    Returns
    -------
    List[Dict]
        ADE data

    """
    suffix = '.json'

    # Get all the IDs of ADE data; only JSON files are read below,
    # so other files in the directory are ignored
    ade_file_ids = sorted({fname[:-len(suffix)]
                           for fname in os.listdir(ade_data_dir)
                           if fname.endswith(suffix) and not fname.startswith('.')})

    # Load ADE data
    ade_data = []
    for fid in ade_file_ids:
        # os.path.join works whether or not the directory ends with a separator
        with open(os.path.join(ade_data_dir, fid + suffix), encoding='utf-8') as f:
            ade_data.extend(json.load(f))

    ade_data = process_ade_files(ade_data)
    if verbose == 1:
        print("\n\nADE data: Done")

    return ade_data
347
348
349
def process_ade_files(ade_data: List[dict]) -> List[dict]:
    """
    Extracts tokens and creates Entity and Relation objects
    from raw json data.

    The input records are not modified.

    Parameters
    ----------
    ade_data : List[dict]
        Raw json data. Each record is read through its 'tokens',
        'entities' (with 'type', 'start', 'end') and 'relations'
        (with 'head', 'tail') keys.

    Returns
    -------
    List[dict]
        One dict per record with the keys 'tokens', 'entities' and
        'relations'. Entities are keyed 'T1', 'T2', ... and relations
        'R1', 'R2', ... in input order.

    """
    ade_records = []

    for ade in ade_data:
        # Tokens
        tokens = ade['tokens']

        # Entities
        entities = {}
        for e_num, ent in enumerate(ade['entities'], start=1):
            ent_id = "T%d" % e_num

            # Rename the type locally instead of rewriting the caller's data
            ent_type = 'ADE' if ent['type'] == 'Adverse-Effect' else ent['type']

            # Convert before doing arithmetic or slicing with the bounds
            start, end = int(ent['start']), int(ent['end'])

            ent_obj = Entity(entity_id=ent_id,
                             entity_type=ent_type)

            # 'end' is exclusive in the raw data; the stored range is inclusive
            ent_obj.set_range([start, end - 1])

            # Every token is followed by a space, including the last one
            ent_obj.set_text(''.join(token + ' ' for token in tokens[start:end]))

            entities[ent_id] = ent_obj

        # Relations
        relations = {}
        for r_num, relation in enumerate(ade['relations'], start=1):
            # 'head' and 'tail' are 0-based positions; entity IDs are 1-based
            entity1 = "T" + str(relation['head'] + 1)
            entity2 = "T" + str(relation['tail'] + 1)

            # A relation pointing at an unknown entity is dropped,
            # but it still consumes its relation number
            if entity1 not in entities or entity2 not in entities:
                continue

            rel_id = "R%d" % r_num
            relations[rel_id] = Relation(relation_id=rel_id,
                                         relation_type='ADE-Drug',
                                         arg1=entities[entity1],
                                         arg2=entities[entity2])

        ade_records.append({"tokens": tokens, "entities": entities, "relations": relations})

    return ade_records
420
421
422
def map_entities(entities: Union[Dict[str, Entity], List[Entity]],
                 actual_relations: Union[Dict[str, Relation], List[Relation]] = None) \
        -> Union[List[Tuple[Relation, None]], List[Tuple[Relation, int]]]:
    """
    Pairs every drug entity with every non-drug entity as a candidate relation.

    Parameters
    ----------
    entities : Union[Dict[str, Entity], List[Entity]]
        Entities, as a list or as a dictionary of them.

    actual_relations : Union[Dict[str, Relation], List[Relation]], optional
        Actual relations (for training data).
        The default is None.

    Returns
    -------
    Union[List[Tuple[Relation, None]], List[Tuple[Relation, int]]]
        One tuple per candidate relation. The second item is None when
        no actual relations are given; otherwise it is 1 if the candidate
        equals one of the actual relations of the same type and 0 if not.

    """
    if isinstance(entities, dict):
        entities = list(entities.values())

    if actual_relations and isinstance(actual_relations, dict):
        actual_relations = list(actual_relations.values())

    # Split the entities into drugs and everything else
    drugs, others = [], []
    for ent in entities:
        bucket = drugs if ent.name.lower() == "drug" else others
        bucket.append(ent)

    # Candidate relations are numbered R1, R2, ... in creation order
    candidates = []
    for drug in drugs:
        for other in others:
            candidates.append(Relation(relation_id="R%d" % (len(candidates) + 1),
                                       relation_type=other.name + "-Drug",
                                       arg1=drug, arg2=other))

    if actual_relations is None:
        return [(cand, None) for cand in candidates]

    # Group the actual relations by relation type
    actual_by_type = defaultdict(list)
    for actual in actual_relations:
        actual_by_type[actual.name].append(actual)

    # Flag a candidate when it matches an actual relation of its type
    return [(cand, int(any(cand == actual for actual in actual_by_type[cand.name])))
            for cand in candidates]
495
496
497
def get_long_relation_table(relations: Iterable[Relation]) -> pd.DataFrame:
    """
    Builds a long-format table of relations with the columns
    ['drug_id', 'drug', 'arg', 'edge'].

    'arg' is the text of the entity related to the drug and 'edge' is the
    part of the relation name before the first '-'.

    Parameters
    ----------
    relations : Iterable[Relation]
        The relations to tabulate.

    Returns
    -------
    pd.DataFrame
        One row per relation.

    """
    columns = ('drug_id', 'drug', 'arg', 'edge')
    rows = []

    for rel in relations:
        # Whichever argument is not arg1-as-Drug is treated as the related entity
        if rel.arg1.name == "Drug":
            drug, other = rel.arg1, rel.arg2
        else:
            drug, other = rel.arg2, rel.arg1

        rows.append((drug.ann_id, drug.ann_text, other.ann_text,
                     rel.name.split('-')[0]))

    # Turn the row tuples into one list per column
    return pd.DataFrame({col: [row[pos] for row in rows]
                         for pos, col in enumerate(columns)})
531
532
533
def get_relation_table(relations: Union[pd.DataFrame, Iterable[Relation]],
                       is_long_df: bool = True) -> str:
    """
    Returns the relations in a wide table format, rendered as HTML.

    Parameters
    ----------
    relations : Union[pd.DataFrame, Iterable[Relation]]
        Either a list of relations, or relations table in long format
        (columns ['drug_id', 'drug', 'arg', 'edge']).

    is_long_df : bool
        Indicator for relations parameter. True indicates the input is
        a long dataframe. False indicates it is a list of relations.

    Returns
    -------
    str
        HTML blob of all the relations in a tabular format.

    """
    # Convert first: a plain list of relations has no drop_duplicates()
    if not is_long_df:
        relations = get_long_relation_table(relations)

    relations = relations.drop_duplicates()

    relations = relations.rename(columns={"drug_id": "Drug ID", "drug": "Drug",
                                          "edge": "Entity Type", "arg": "Entity Text"})

    # One row per (drug, entity type); its texts are joined by newlines
    relation_df = (
        relations
        .groupby(["Drug ID", "Drug", "Entity Type"])["Entity Text"]
        .apply(lambda texts: "\n".join(texts))
        .reset_index(name="Entity Text")
        .set_index(["Drug ID", "Drug", "Entity Type"])
    )

    empty_header = "    <tr style=\"text-align: right;\">\n      <th></th>\n      <th></th>\n      <th></th>\n      <th>Entity Text</th>\n    </tr>\n"
    empty_colname = "<th></th>"

    relation_html = (
        relation_df
        .to_html(classes=['table'], border=0)
        # Newlines inside cells appear as a literal backslash-n in the HTML
        .replace("\\n", "<br>")
        # Drop the header row that only names the value column ...
        .replace(empty_header, "")
        # ... and put that name into the empty cell of the remaining header row
        .replace(empty_colname, "<th>Entity Text</th>")
    )
    return relation_html
582
583
584
def draw_progress_bar(current, total, string='', bar_len=20):
    """
    Draws a progress bar, like ``Progress: [====>    ] 2/5``

    The bar is written to stdout after a carriage return, so that
    successive calls overwrite the same line.

    Parameters
    ------------
    current: int/float
             Current progress

    total: int/float
           The total from which the current progress is made.
           A total of 0 is drawn as a completed bar.

    string: str
            Additional details to write along with progress

    bar_len: int
            Length of progress bar
    """
    if total:
        # Cap at 1 so the bar never grows past bar_len
        percent = min(current / total, 1)
    else:
        # Nothing to do counts as done; also avoids dividing by zero
        percent = 1

    # The arrow head is only shown while the bar is still growing
    arrow = "" if percent == 1 else ">"
    bar = "=" * int(bar_len * percent) + arrow

    # Carriage return, returns to the beginning of line to overwrite
    sys.stdout.write("\r")
    sys.stdout.write("Progress: [{:<{}}] {}/{}".format(bar, bar_len, current, total) + string)
    sys.stdout.flush()
611
612
613
def is_whitespace(char):
    """
    Checks if the character is a whitespace

    Space, tab, carriage return, newline and U+202F (narrow no-break
    space) count as whitespace.

    Parameters
    --------------
    char: str
          A single character string to check

    Returns
    --------------
    bool
          True if char is one of the whitespace characters above.
          Any other string, including an empty or longer one, gives False.
    """
    # Tuple membership compares whole strings, so "" and multi-character
    # input are simply not found instead of raising
    return char in (" ", "\t", "\r", "\n", "\u202f")
627
628
629
def is_punct(char):
    """
    Tells whether the character is one of the punctuation marks
    ``.``, ``,``, ``!``, ``?`` or a backslash.

    Parameters
    --------------
    char: str
          A single character string to check

    Returns
    --------------
    bool
          True for the characters listed above, False otherwise.
    """
    return char in (".", ",", "!", "?", '\\')
642
643
644
def save_pickle(file, variable):
    """
    Pickles a variable to disk and prints the path it was written to.

    Parameters
    -----------
    file: str
          File name/path in which the variable is to be stored.
          ".pkl" is appended unless the text after the last "." is "pkl".

    variable: object
              The variable to be stored in a file
    """
    # Text after the last dot (the whole string when there is no dot)
    extension = file.rpartition('.')[2]
    if extension != "pkl":
        file = file + ".pkl"

    with open(file, 'wb') as handle:
        dump(variable, handle)
        print("Variable successfully saved in " + file)
662
663
664
def open_pickle(file):
    """
    Loads and returns the variable stored in a pickle file.

    NOTE: unpickling can run arbitrary code, so only open files
    that come from a trusted source.

    Parameters
    -----------
    file: str
          File name/path from which variable is to be loaded.
          ".pkl" is appended unless the text after the last "." is "pkl".

    Returns
    -----------
    object
          The unpickled variable.
    """
    # Text after the last dot (the whole string when there is no dot)
    extension = file.rpartition('.')[2]
    path = file if extension == "pkl" else file + ".pkl"

    with open(path, 'rb') as handle:
        contents = load(handle)
    return contents