src/preprocessing/generate_relations.py
# coding: utf-8

# Base Dependencies
# ------------------
import json
import xml.etree.ElementTree as ET

from os.path import join as pjoin
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple, Set, Dict

# Local Dependencies
# --------------------
from models import Document, Entity, RelationN2C2, RelationDDI, RelationCollection
from utils import files_ddi, files_n2c2, doc_id_n2c2, make_dir

# 3rd-Party Dependencies
# ----------------------
from PyRuSH import PyRuSHSentencizer

# Constants
# ---------
from constants import N2C2_PATH, DDI_PATH


# Auxiliary Functions
# ------------------
def read_txt(file: Path) -> str:
    """Reads a .txt file

    Args:
        file (Path): path to the .txt file
    """
    # read text file
    with open(file, "r", encoding="utf-8") as fin:
        text = fin.read()

    return text


def read_json(file: Path) -> dict:
    """Reads a .json file

    Args:
        file (Path): path to the .json file
    """
    return json.loads(read_txt(file))


def read_annotations_n2c2(file: Path) -> Tuple[List[Entity], Set[str]]:
    """Reads an n2c2 .ann file and extracts the entities and the relations

    Args:
        file (Path): path to the n2c2 annotation file
    """
    # read file
    with open(file, "r", encoding="utf-8") as fin:
        annotations: List[str] = fin.readlines()

    # process file
    doc_id: str = doc_id_n2c2(file)
    entities: List[Entity] = list()
    gt_relations: Set[str] = set()  # ground-truth relations
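
    # Illustrative .ann lines (brat standoff format; the IDs, offsets and
    # strings below are made-up examples; fields after the ID are tab-separated):
    #   T1    Strength 1535 1540    325mg
    #   R1    Strength-Drug Arg1:T1 Arg2:T2
    # "T" lines become Entity objects; each "R" line is stored in
    # gt_relations as an "Arg1-Arg2" key, e.g. "T1-T2".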
    for line in annotations:
        if line.startswith("T"):  # process entity
            entities.append(Entity.from_n2c2_annotation(doc_id, line))

        elif line.startswith("R"):  # process relation
            rel_id, definition = line.strip().split("\t")
            rel_type, entity1_id, entity2_id = definition.split()
            entity1_id = entity1_id.split(":")[1]
            entity2_id = entity2_id.split(":")[1]
            gt_relations.add("{}-{}".format(entity1_id, entity2_id))

        else:  # ignore annotator's note
            continue

    # sort entities by their end character
    entities.sort(key=lambda ent: ent.end)

    return entities, gt_relations


# Main Functions
# ---------------
def generate_relations(
    dataset: str, save_to_disk: bool = True
) -> Dict[str, RelationCollection]:
    """Generates relations of a given dataset and saves them to disk

    Args:
        dataset (str): dataset's name
        save_to_disk (bool, optional): whether the relation collections are saved to disk as a datading. Defaults to True.

    Raises:
        ValueError: unsupported dataset

    Returns:
        Dict[str, RelationCollection]: train and test relation collections
    """
    if dataset == "n2c2":
        return generate_relations_n2c2(save_to_disk=save_to_disk)
    elif dataset == "ddi":
        return generate_relations_ddi(save_to_disk=save_to_disk)
    else:
        raise ValueError("unsupported dataset '{}'".format(dataset))


def generate_relations_n2c2(save_to_disk: bool = True) -> Dict[str, RelationCollection]:
    """Generates relations of the n2c2 dataset:
        1. Per document, read the clinical text, split into sentences.
        2. Read all entities and all ground-truth relations from the .ann file.
        3. Separate the entities into drugs and attributes.
        4. For each relation type, combine each drug with each attribute within the same sentence.

    Args:
        save_to_disk (bool): whether the relation collections are saved to disk as a datading. Defaults to True.

    Returns:
        Dict[str, RelationCollection]: train and test relation collections
    """
    print("Generating relations for the n2c2 dataset...\n")

    dataset = files_n2c2()
    collections = {}

    for split, files in dataset.items():
        print(split, ": ")
        split_entities = []
        split_relations = []

        for basepath in tqdm(files):
            # process clinical text, split into sentences
            document: Document = Document.from_json(read_txt(basepath + ".json"))

            # read annotation file
            entities, gt_relations = read_annotations_n2c2(basepath + ".ann")

            # generate relations
            relations = RelationN2C2.generate_relations_n2c2(
                document, entities, gt_relations, (split == "test")
            )

            split_entities.extend(entities)
            split_relations.extend(relations)

        # create collection
        collection = RelationCollection(split_relations)

        # remove invalid relations
        collection = collection[collection.valid_indexes()]

        # write to datading
        if save_to_disk:
            make_dir(pjoin(N2C2_PATH, "{}_datading".format(split)))
            collection.to_datading(
                pjoin(N2C2_PATH, "{}_datading".format(split), "relations.msgpack")
            )

        collections[split] = collection

    return collections


def generate_relations_ddi(save_to_disk: bool = True) -> Dict[str, RelationCollection]:
    """Generates relations of the DDI dataset

    Args:
        save_to_disk (bool): whether the relation collections are saved to disk as a datading. Defaults to True.

    Returns:
        Dict[str, RelationCollection]: train and test relation collections
    """
    print("Generating relations for the DDI Extraction corpus...")

    dataset = files_ddi()
    collections = {}

    for split, files in dataset.items():
        print(split, ": ")
        split_relations = []
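
        # Each DDI corpus file is an XML document whose <sentence> elements
        # carry <entity> and <pair> annotations; RelationDDI.generate_relations_ddi
        # is assumed to walk that tree and emit one candidate relation per <pair>.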
        for file in tqdm(files):
            xml_tree = ET.parse(file)
            relations = RelationDDI.generate_relations_ddi(xml_tree)
            split_relations.extend(relations)

        # create collection
        collection = RelationCollection(split_relations)

        # remove invalid relations
        collection = collection[collection.valid_indexes()]

        # write to datading
        if save_to_disk:
            make_dir(pjoin(DDI_PATH, "{}_datading".format(split)))
            collection.to_datading(
                pjoin(DDI_PATH, "{}_datading".format(split), "relations.msgpack")
            )

        collections[split] = collection

    return collections
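

# Minimal usage sketch (a hypothetical entry point, assuming N2C2_PATH and
# DDI_PATH already contain the preprocessed .json/.ann and XML files):
if __name__ == "__main__":
    for corpus in ("n2c2", "ddi"):
        generate_relations(corpus, save_to_disk=True)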