[735bb5]: / src / utils.py

Download this file

216 lines (157 with data), 5.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Base Dependencies
# -----------------
import functools
import numpy as np
import operator
import os
import random
import re
from glob import glob
from os.path import join as pjoin
from pathlib import Path
from typing import List, Any, Union
# Local Dependencies
# ------------------
from constants import N2C2_PATH, DDI_PATH, N2C2_ANNONYM_PATTERNS, DDI_ALL_TYPES
# 3rd-Party Dependencies
# ----------------------
import torch
from torch import nn
from transformers import set_seed as transformers_set_seed
def set_seed(seed: int) -> None:
    """Seed every relevant RNG: transformers, torch, numpy and random.

    Args:
        seed (int): random seed
    """
    transformers_set_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Only touch CUDA state when a usable GPU is actually present.
    gpu_available = torch.cuda.is_available() and torch.cuda.device_count() > 0
    if gpu_available:
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(seed)
def flatten(array: List[List[Any]]) -> List[Any]:
    """Flatten a nested 2D list into a single list.

    In-place list concatenation (`+=`) avoids building intermediate
    lists, unlike `[item for sub in array for item in sub]`.

    Args:
        array (List[List[Any]]): a nested list

    Returns:
        List[Any]: flattened list
    """
    flat: List[Any] = []
    for subarray in array:
        flat += subarray
    return flat
def write_list_to_file(output_path: Path, array: List[Any]) -> None:
    """Write each entry of `array` as its own line to `output_path`.

    Args:
        output_path (Path): output file path
        array (List[Any]): list of entries (each is str-formatted)
    """
    lines = [f"{item}\n" for item in array]
    with output_path.open("w", encoding="utf-8") as handle:
        handle.writelines(lines)
def read_list_from_file(input_path: Path) -> List[str]:
    """Read a list of strings, one per line, from `input_path`.

    Args:
        input_path (Path): input file path; may be None

    Returns:
        List[str]: lines without trailing newlines; empty list when
            `input_path` is None
    """
    if input_path is None:
        return []
    content = input_path.read_text(encoding="utf-8")
    # splitlines() already drops the terminators; rstrip is a safety no-op.
    return [line.rstrip("\n") for line in content.splitlines()]
def make_dir(dirpath: str) -> None:
    """Create directory `dirpath` (and missing parents) if it doesn't exist.

    Uses `exist_ok=True` instead of the previous exists()/makedirs() pair,
    which had a check-then-create race: another process creating the
    directory between the two calls would make makedirs() raise.

    Args:
        dirpath (str): path of the directory to create
    """
    os.makedirs(dirpath, exist_ok=True)
def freeze_params(module: nn.Module) -> None:
    """Freeze all parameters of `module` so training won't update them.

    Args:
        module (nn.Module): module whose parameters are frozen
    """
    # parameters() yields the same tensors as named_parameters(),
    # without the names we were discarding anyway.
    for param in module.parameters():
        param.requires_grad = False
def ddi_binary_relation(rel_type: Union[str, int]) -> int:
    """Convert a DDI relation type into a binary label.

    Args:
        rel_type (str): relation type

    Returns:
        int: 0 if the relation type is `"NO-REL"`, `"0"` or `0`;
            1 if it is one of `["EFFECT", "MECHANISM", "ADVISE", "INT"]`
            or an integer `> 0`.
    """
    if isinstance(rel_type, str):
        # Known type names map to their index in DDI_ALL_TYPES; any other
        # string is expected to hold an integer literal.
        if rel_type in DDI_ALL_TYPES:
            numeric = DDI_ALL_TYPES.index(rel_type)
        else:
            numeric = int(rel_type)
    else:
        numeric = rel_type
    return 0 if numeric == 0 else 1
def doc_id_n2c2(filepath: str) -> str:
    """Extract the document id (last run of 2+ digits) from an n2c2 filepath."""
    digit_runs = re.findall(r"\d{2,}", filepath)
    return digit_runs[-1]
def doc_id_ddi(filepath: str) -> str:
    """Extract the document id of a DDI filepath.

    The id is the file's base name without its extension. The previous
    implementation used `filepath.split()` — a *whitespace* split — so for
    any path with directory components (and no spaces) the directories
    leaked into the returned id; `[:-4]` also assumed a 4-char suffix.

    Args:
        filepath (str): path to a DDI corpus file, e.g. ".../Aspirin.xml"

    Returns:
        str: document id, e.g. "Aspirin"
    """
    file_name = os.path.basename(filepath)
    return os.path.splitext(file_name)[0]
def clean_text_ddi(text: str) -> str:
    """Clean a text fragment from a DDI document.

    Collapses runs of whitespace into single spaces, then inserts a
    space after every semicolon.

    Args:
        text (str): text fragment

    Returns:
        str: cleaned text fragment
    """
    collapsed = re.sub(r"[\s]+", " ", text)
    # A literal character substitution doesn't need a regex.
    return collapsed.replace(";", "; ")
def clean_text_n2c2(text: str) -> str:
    """Clean a text fragment from an n2c2 document.

    Replaces newlines with spaces, rewrites known anonymization markers
    to their type label, drops unmatched markers, collapses whitespace
    and reduces doubled periods.

    Args:
        text (str): text fragment

    Returns:
        str: cleaned text fragment
    """
    # Replace newlines by spaces.
    cleaned = re.sub(r"\n", " ", text)
    # Substitute each known anonymization pattern by its type label.
    for label, pattern in N2C2_ANNONYM_PATTERNS.items():
        cleaned = re.sub(pattern, label, cleaned)
    # Drop anonymization markers that matched no known pattern.
    cleaned = re.sub(r"\[\*\*[^\*]+\*\*\]", "", cleaned)
    # Collapse runs of whitespace into a single space.
    cleaned = re.sub(r"[\s]+", " ", cleaned)
    # Reduce doubled periods to one.
    return re.sub(r"\.\.", ".", cleaned)
def files_n2c2():
    """Load the filepaths of the n2c2 dataset splits.

    Returns a dict with "train"/"test" lists of path stems (the
    trailing ".txt" is stripped from each file).
    """
    splits = {}
    for split in ["train", "test"]:
        txt_files = glob(pjoin(N2C2_PATH, split, "*.txt"))
        splits[split] = [path[:-4] for path in txt_files]
    return splits
def files_ddi():
    """Load the filepaths of the DDI corpus splits.

    Returns a dict with "train"/"test" lists of XML paths; each split
    combines the DrugBank and MedLine subsets (the test split lives
    under an extra "re" directory).
    """
    train_files = glob(pjoin(DDI_PATH, "train", "DrugBank", "*.xml")) + glob(
        pjoin(DDI_PATH, "train", "MedLine", "*.xml")
    )
    test_files = glob(pjoin(DDI_PATH, "test", "re", "DrugBank", "*.xml")) + glob(
        pjoin(DDI_PATH, "test", "re", "MedLine", "*.xml")
    )
    return {"train": train_files, "test": test_files}