Diff of /src/utils.py [000000] .. [735bb5]

Switch to unified view

a b/src/utils.py
1
# Base Dependencies
2
# -----------------
3
import functools
4
import numpy as np
5
import operator
6
import os
7
import random
8
import re
9
10
from glob import glob
11
from os.path import join as pjoin
12
from pathlib import Path
13
from typing import List, Any, Union
14
15
# Local Dependencies
16
# ------------------
17
from constants import N2C2_PATH, DDI_PATH, N2C2_ANNONYM_PATTERNS, DDI_ALL_TYPES
18
19
# 3rd-Party Dependencies
20
# ----------------------
21
import torch
22
from torch import nn
23
from transformers import set_seed as transformers_set_seed
24
25
26
def set_seed(seed: int) -> None:
    """Seeds every random number generator used by the project.

    The seed is handed, in order, to the `transformers` seeding helper,
    `torch`, `numpy` and the built-in `random` module. When at least one
    CUDA device is usable, cuDNN is additionally switched to deterministic
    mode and all CUDA generators are seeded.

    Args:
        seed (int): random seed
    """
    seeders = (
        transformers_set_seed,
        torch.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    # nothing GPU-related to configure without a usable CUDA device
    if not torch.cuda.is_available() or torch.cuda.device_count() <= 0:
        return

    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
40
41
42
def flatten(array: List[List[Any]]) -> List[Any]:
    """Concatenates the sub-lists of a nested 2D list into one flat list.

    Only one level of nesting is removed. The sub-lists are folded into a
    fresh accumulator using in-place concatenation, so neither `array` nor
    its sub-lists are modified.

    Args:
        array (List[List[Any]]): a nested list
    Returns:
        List[Any]: flattened list
    """
    accumulator: List[Any] = []
    return functools.reduce(operator.iconcat, array, accumulator)
53
54
55
def write_list_to_file(output_path: Path, array: List[Any]) -> None:
    """
    Writes the entries of `array` to `output_path`, one entry per line.

    The file is opened in write mode with UTF-8 encoding, so existing
    content is replaced. Each entry is formatted as a string and followed
    by a newline character.

    Args:
        output_path (Path): output file path
        array (List[Any]): entries to write
    """
    with output_path.open("w", encoding="utf-8") as sink:
        sink.writelines(f"{item}\n" for item in array)
66
67
68
def read_list_from_file(input_path: Union[Path, None]) -> List[str]:
    """
    Reads a list of str from the file in `input_path`, one entry per line.

    Lines are delimited by newlines only. Other characters that
    `str.splitlines` treats as line boundaries (form feed, vertical tab,
    `\\x85`, `\\u2028`, ...) are kept inside the entry, so a file produced by
    writing one entry per line is read back with the same number of entries.

    Args:
        input_path (Path): input file path, or `None`
    Returns:
        List[str]: list of strings without their trailing newline; an empty
        list if `input_path` is `None`
    """
    if input_path is None:
        return []

    # Iterating the file object splits on newlines only (universal-newline
    # mode), unlike `str.splitlines`, which also breaks on \x0c, \u2028, etc.
    with input_path.open("r", encoding="utf-8") as opened_file:
        return [line.rstrip("\n") for line in opened_file]
85
86
87
def make_dir(dirpath: str) -> None:
    """Creates a directory (and any missing parents) if it doesn't exist.

    Args:
        dirpath (str): path of the directory to create

    Raises:
        FileExistsError: if `dirpath` exists but is not a directory
    """
    # `exist_ok=True` avoids the check-then-create race of testing
    # `os.path.exists` first when several processes create the same path.
    os.makedirs(dirpath, exist_ok=True)
91
92
93
def freeze_params(module: nn.Module) -> None:
    """
    Disables gradient computation for every parameter of `module`,
    so that they are not updated during training.

    Args:
        module (nn.Module): module whose parameters are frozen
    """
    for param in module.parameters():
        param.requires_grad = False
103
104
105
def ddi_binary_relation(rel_type: Union[str, int]) -> int:
    """Converts a DDI's relation type into binary

    A string is first turned into a number: its position in `DDI_ALL_TYPES`
    when it is one of the known type names, otherwise its value parsed with
    `int`. Integers are used as they are.

    Args:
        rel_type (Union[str, int]): relation type, as a name, a numeric
            string or an integer

    Returns:
        int: 0 if the relation type resolves to `0`, 1 otherwise.

    Raises:
        ValueError: if `rel_type` is a string that is neither a known type
            name nor an integer literal
    """
    if isinstance(rel_type, str):
        if rel_type in DDI_ALL_TYPES:
            type_index = DDI_ALL_TYPES.index(rel_type)
        else:
            type_index = int(rel_type)
    else:
        type_index = rel_type

    return 0 if type_index == 0 else 1
126
127
128
def doc_id_n2c2(filepath: str) -> str:
    """Extracts the document id of a n2c2 filepath.

    The id is the last run of two or more consecutive digits found in
    `filepath`. An `IndexError` is raised if the path contains no such run.
    """
    digit_runs = re.findall(r"\d{2,}", filepath)
    return digit_runs[-1]
131
132
133
def doc_id_ddi(filepath: str) -> str:
    """Extracts the document id of a ddi filepath.

    The id is the file name of `filepath` without its directory part and
    without its extension, e.g. `"train/DrugBank/Abarelix_ddi.xml"` gives
    `"Abarelix_ddi"`.

    Args:
        filepath (str): path of a DDI document

    Returns:
        str: document id
    """
    # Split on the path separator: a bare `str.split()` splits on whitespace
    # and would keep the directories (or cut paths that contain spaces).
    file_name = os.path.basename(filepath)
    return os.path.splitext(file_name)[0]
137
138
139
def clean_text_ddi(text: str) -> str:
    """Cleans text of a text fragment from a ddi document

    Every `;` is followed by a space and every run of whitespace (spaces,
    tabs, newlines) is collapsed into a single space.

    Args:
        text (str): text fragment

    Returns:
        str: cleaned text fragment
    """
    # include space after ;
    text = re.sub(r";", "; ", text)

    # remove more than one space; done last so that a `;` which was already
    # followed by whitespace does not end up followed by two spaces
    text = re.sub(r"[\s]+", " ", text)

    return text
155
156
157
def clean_text_n2c2(text: str) -> str:
    """Cleans text of a text fragment from a n2c2 document

    The passes run in a fixed order: newlines become spaces, anonymization
    markers matching `N2C2_ANNONYM_PATTERNS` are replaced by the key of the
    pattern they match, any remaining `[** ... **]` marker is removed, runs
    of whitespace are collapsed and each pair of dots is replaced by one dot.
    Leading and trailing whitespace is not stripped.

    Args:
        text (str): text fragment

    Returns:
        str: cleaned text fragment
    """
    # newlines become plain spaces
    cleaned = re.sub(r"\n", " ", text)

    # known anonymization markers are substituted by their replacement
    for replacement, regex in N2C2_ANNONYM_PATTERNS.items():
        cleaned = re.sub(regex, replacement, cleaned)

    # order matters: markers are dropped before whitespace is collapsed
    final_passes = (
        (r"\[\*\*[^\*]+\*\*\]", ""),  # anonymizations that matched no pattern
        (r"[\s]+", " "),  # more than one space
        (r"\.\.", "."),  # two points by one
    )
    for regex, replacement in final_passes:
        cleaned = re.sub(regex, replacement, cleaned)

    return cleaned
187
188
189
def files_n2c2():
    """Loads the filepaths of the n2c2 dataset splits.

    Returns:
        dict: maps `"train"` and `"test"` to the paths of the `*.txt` files
        found in the matching sub-folder of `N2C2_PATH`, each with its last
        four characters (the `.txt` extension) removed.
    """
    return {
        split: [path[:-4] for path in glob(pjoin(N2C2_PATH, split, "*.txt"))]
        for split in ("train", "test")
    }
198
199
200
def files_ddi():
    """Loads the filepaths of the DDI corpus splits.

    Returns:
        dict: maps `"train"` and `"test"` to the list of `*.xml` paths of
        that split, DrugBank files first and MedLine files after. Train files
        are looked up in `<DDI_PATH>/train/<source>` and test files in
        `<DDI_PATH>/test/re/<source>`.
    """
    sources = ("DrugBank", "MedLine")
    split_roots = {
        "train": (DDI_PATH, "train"),
        "test": (DDI_PATH, "test", "re"),
    }

    splits = {}
    for split, root in split_roots.items():
        found = []
        for source in sources:
            found += glob(pjoin(*root, source, "*.xml"))
        splits[split] = found

    return splits