|
src/vocabulary.py
|
# coding: utf-8
"""
Vocabulary module

Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/vocabulary.py
"""

# Base Dependencies
# -----------------
import sys
import logging
import numpy as np
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Local Dependencies
# ------------------
from constants import (
    BOS_ID,
    BOS_TOKEN,
    EOS_ID,
    EOS_TOKEN,
    PAD_ID,
    PAD_TOKEN,
    UNK_ID,
    UNK_TOKEN,
)
from models.relation_collection import RelationCollection
from utils import read_list_from_file, write_list_to_file

# Constants
# ---------
from constants import DATASETS_PATHS, N2C2_VOCAB_PATH, DDI_VOCAB_PATH

VOC_MIN_FREQ = 10


logger = logging.getLogger(__name__)


class Vocabulary:
    """Vocabulary represents mapping between tokens and indices."""

    def __init__(self, tokens: List[str]) -> None:
        """
        Create vocabulary from list of tokens.
        Special tokens are added if not already in list.

        Args:
            tokens (List[str]): list of tokens
        """
        # note: unlike the joeynmt original, _stoi is a plain dict (lookups of unknown
        # tokens return UNK_ID and do not grow it), so it stays in sync with _itos

        # special symbols
        self.specials = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]

        # don't allow to access _stoi and _itos outside of this class
        self._stoi: Dict[str, int] = {}  # string to index
        self._itos: List[str] = []  # index to string

        # construct
        self.add_tokens(tokens=self.specials + tokens)
        assert len(self._stoi) == len(self._itos)

        # assign after stoi is built
        self.pad_index = self.lookup(PAD_TOKEN)
        self.bos_index = self.lookup(BOS_TOKEN)
        self.eos_index = self.lookup(EOS_TOKEN)
        self.unk_index = self.lookup(UNK_TOKEN)
        assert self.pad_index == PAD_ID
        assert self.bos_index == BOS_ID
        assert self.eos_index == EOS_ID
        assert self.unk_index == UNK_ID
        assert self._itos[UNK_ID] == UNK_TOKEN
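
    # Construction sketch (illustrative): specials are prepended and duplicates are
    # ignored, so token indices follow deterministically from the asserts above.
    #   v = Vocabulary(["hello", "world", "hello"])
    #   len(v)             # -> 6 (4 specials + 2 unique tokens)
    #   v.lookup("hello")  # -> 4
    #   v.lookup("xyz")    # -> UNK_ID (3)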
|
|
    def add_tokens(self, tokens: List[str]) -> None:
        """
        Add list of tokens to vocabulary

        Args:
            tokens (List[str]): list of tokens to add to the vocabulary
        """
        for t in tokens:
            new_index = len(self._itos)
            # add to vocab if not already there
            if t not in self._itos:
                self._itos.append(t)
                self._stoi[t] = new_index

    def to_file(self, file: Path) -> None:
        """
        Save the vocabulary to a file, by writing token with index i in line i.

        Args:
            file (Path): path to file where the vocabulary is written
        """
        write_list_to_file(file, self._itos)

    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is outside the vocabulary (i.e. maps to <unk>)

        Args:
            token (str): token to check
        Returns:
            bool: True if the token is NOT covered by the vocabulary, False otherwise
        """
        return self.lookup(token) == UNK_ID

    def lookup(self, token: str) -> int:
        """
        look up the encoding dictionary. (needed for multiprocessing)

        Args:
            token (str): surface str
        Returns:
            int: token id
        """
        return self._stoi.get(token, UNK_ID)

    def __len__(self) -> int:
        return len(self._itos)

    def __eq__(self, other) -> bool:
        if isinstance(other, Vocabulary):
            return self._itos == other._itos
        return False

    def array_to_sentence(
        self, array: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
    ) -> List[str]:
        """
        Converts an array of IDs to a sentence, optionally cutting the result off at the
        end-of-sequence token.

        Args:
            array (numpy.ndarray): 1D array containing indices
            cut_at_eos (bool): cut the decoded sentences at the first <eos>
            skip_pad (bool): skip generated <pad> tokens

        Returns:
            List[str]: list of strings (tokens)
        """
        sentence = []
        for i in array:
            s = self._itos[i]
            if skip_pad and s == PAD_TOKEN:
                continue
            sentence.append(s)
            # break at the position AFTER eos
            if cut_at_eos and s == EOS_TOKEN:
                break
        return sentence
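
    # Decoding sketch (illustrative; ids follow from the specials order asserted in
    # __init__, i.e. PAD_ID=0, BOS_ID=1, EOS_ID=2, UNK_ID=3):
    #   v = Vocabulary(["hello", "world"])               # "hello" -> 4, "world" -> 5
    #   v.array_to_sentence(np.array([1, 4, 5, 2, 0, 0]))
    #   # -> [BOS_TOKEN, "hello", "world", EOS_TOKEN]    (pads skipped, cut after <eos>)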
|
|
    def arrays_to_sentences(
        self, arrays: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
    ) -> List[List[str]]:
        """
        Convert multiple arrays containing sequences of token IDs to their sentences,
        optionally cutting them off at the end-of-sequence token.

        Args:
            arrays (numpy.ndarray): 2D array containing indices
            cut_at_eos (bool): cut the decoded sentences at the first <eos>
            skip_pad (bool): skip generated <pad> tokens
        Returns:
            List[List[str]]: list of list of strings (tokens)
        """
        return [
            self.array_to_sentence(
                array=array, cut_at_eos=cut_at_eos, skip_pad=skip_pad
            )
            for array in arrays
        ]

    def sentences_to_ids(
        self,
        sentences: List[List[str]],
        padded: bool = False,
        bos: bool = False,
        eos: bool = False,
    ) -> Tuple[List[List[int]], List[int]]:
        """
        Encode sentences to indices; optionally prepend <bos>, append <eos>, and pad
        every sequence to the length of the longest sentence in the batch.

        Args:
            sentences (List[List[str]]): list of tokenized sentences
            padded (bool): pad each encoded sequence to the maximum length
            bos (bool): prepend the <bos> token to each sequence
            eos (bool): append the <eos> token to each sequence

        Returns:
            - encoded (and optionally padded) ids
            - original lengths before padding (including <bos>/<eos> if added)
        """
        max_len = max([len(sent) for sent in sentences])
        if bos:
            max_len += 1
        if eos:
            max_len += 1
        sentences_enc, lengths = [], []
        for sent in sentences:
            encoded = [self.lookup(s) for s in sent]
            if bos:
                encoded = [self.bos_index] + encoded
            if eos:
                encoded = encoded + [self.eos_index]
            if padded:
                offset = max(0, max_len - len(encoded))
                sentences_enc.append(encoded + [self.pad_index] * offset)
            else:
                sentences_enc.append(encoded)
            lengths.append(len(encoded))
        return sentences_enc, lengths
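
    # Encoding sketch (illustrative; ids follow the specials order asserted in __init__):
    #   v = Vocabulary(["hello", "world"])
    #   v.sentences_to_ids([["hello"], ["hello", "world"]],
    #                      padded=True, bos=True, eos=True)
    #   # -> ([[1, 4, 2, 0], [1, 4, 5, 2]], [3, 4])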
|
|
    def log_vocab(self, k: int) -> str:
        """Return the first k vocabulary entries as a formatted string."""
        return " ".join(f"({i}) {t}" for i, t in enumerate(self._itos[:k]))

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(len={self.__len__()}, "
            f"specials={self.specials})"
        )

    @staticmethod
    def sort_and_cut(
        counter: Counter, max_size: int = sys.maxsize, min_freq: int = -1
    ) -> List[str]:
        """
        Cut the counter to the most frequent tokens, sorted by frequency (descending)
        and alphabetically among ties.

        Args:
            counter (Counter): token counts (flattened token list in a Counter object)
            max_size (int): maximum size of vocabulary
            min_freq (int): minimum frequency for an item to be included

        Returns:
            List[str]: valid tokens
        """
        # filter counter by min frequency
        if min_freq > -1:
            counter = Counter({t: c for t, c in counter.items() if c >= min_freq})

        # sort by frequency, then alphabetically
        tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # cut off
        vocab_tokens = [i[0] for i in tokens_and_frequencies[:max_size]]
        assert len(vocab_tokens) <= max_size, (len(vocab_tokens), max_size)
        return vocab_tokens
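
    # Sorting/cutoff sketch (illustrative):
    #   Vocabulary.sort_and_cut(Counter({"b": 3, "a": 3, "c": 1}), max_size=2, min_freq=2)
    #   # -> ["a", "b"]   ("c" dropped by min_freq; ties broken alphabetically)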
|
|
    @staticmethod
    def build_vocab(
        cfg: Dict, collection: Optional[RelationCollection] = None
    ) -> "Vocabulary":
        """
        Builds the vocabulary either from a vocab file or from a relation collection.

        Args:
            cfg (Dict): data cfg (may contain `voc_file`, `voc_min_freq`, `voc_limit`)
            collection (Optional[RelationCollection]): collection whose tokens are
                used to build the vocabulary when no vocab file is given

        Returns:
            Vocabulary: created from either `collection` tokens or `voc_file`
        """
        vocab_file = cfg.get("voc_file", None)
        min_freq = cfg.get("voc_min_freq", 1)  # min freq for an item to be included
        max_size = int(cfg.get("voc_limit", sys.maxsize))  # max size of vocabulary
        assert max_size > 0

        if vocab_file is not None:
            # load from file; tokens are kept as stored (`sort_and_cut()` is not applied)
            unique_tokens = read_list_from_file(Path(vocab_file))

        elif collection is not None:
            # tokenize sentences
            tokens = []
            for doc in collection.tokens:
                for t in doc:
                    tokens.append(t.text.lower())

            # newly create unique token list (language-wise)
            counter = Counter(tokens)
            unique_tokens = Vocabulary.sort_and_cut(counter, max_size, min_freq)
        else:
            raise ValueError(
                "Please provide a vocab file path or a relation collection."
            )

        vocab = Vocabulary(unique_tokens)
        assert len(vocab) <= max_size + len(vocab.specials), (len(vocab), max_size)

        # check that every special token except UNK is covered by the vocabulary
        for s in vocab.specials:
            assert s == UNK_TOKEN or not vocab.is_unk(s)

        return vocab
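
    # Config sketch (illustrative; "path/to/vocab.txt" and `train_collection` are
    # placeholders):
    #   Vocabulary.build_vocab({"voc_file": "path/to/vocab.txt"})
    #   Vocabulary.build_vocab({"voc_min_freq": 2, "voc_limit": 30000},
    #                          collection=train_collection)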
|
|
    @staticmethod
    def create_vocabulary(
        dataset: str,
        train_collection: RelationCollection,
        save_to_disk: bool = True,
    ) -> "Vocabulary":
        """Creates the vocabulary of a dataset

        Args:
            dataset (str): dataset's name
            train_collection (RelationCollection): train split of the dataset
            save_to_disk (bool): if True, the created vocabulary is written to disk

        Returns:
            Vocabulary: vocabulary built from the train split
        """
        # configuration
        cfg = {
            "voc_min_freq": VOC_MIN_FREQ,
        }
        # create vocabulary
        vocabulary = Vocabulary.build_vocab(cfg=cfg, collection=train_collection)
        logger.info(
            "Vocabulary created for %s dataset: %d tokens", dataset, len(vocabulary)
        )

        # save vocab to file (same location that load_vocab() reads from)
        if save_to_disk:
            vocab_file = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
            vocabulary.to_file(vocab_file)

        return vocabulary
|
|
def load_vocab(dataset: str) -> "Vocabulary":
    """Loads the vocabulary of a dataset

    Args:
        dataset (str): dataset's name

    Returns:
        Vocabulary: vocabulary of the dataset
    """
    path = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]

    return Vocabulary(read_list_from_file(path))
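

# Minimal smoke-test sketch (illustrative, not part of the original module): build a
# tiny vocabulary directly from tokens and round-trip a sentence through encoding and
# decoding. The token list below is a hypothetical example.
if __name__ == "__main__":
    demo_vocab = Vocabulary(["hello", "world"])
    ids, lengths = demo_vocab.sentences_to_ids(
        [["hello", "world"]], padded=True, bos=True, eos=True
    )
    print(ids, lengths)  # -> [[1, 4, 5, 2]] [4]
    print(demo_vocab.array_to_sentence(np.array(ids[0])))  # tokens incl. <bos>/<eos>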