src/vocabulary.py
# coding: utf-8
"""
Vocabulary module

Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/vocabulary.py
"""

# Base Dependencies
# -----------------
import sys
import logging
import numpy as np
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Local Dependencies
# ------------------
from constants import (
    BOS_ID,
    BOS_TOKEN,
    EOS_ID,
    EOS_TOKEN,
    PAD_ID,
    PAD_TOKEN,
    UNK_ID,
    UNK_TOKEN,
)
from models.relation_collection import RelationCollection
from utils import read_list_from_file, write_list_to_file

# Constants
# ---------
from constants import DATASETS_PATHS, N2C2_VOCAB_PATH, DDI_VOCAB_PATH

VOC_MIN_FREQ = 10


logger = logging.getLogger(__name__)


class Vocabulary:
    """Vocabulary represents a mapping between tokens and indices."""

    def __init__(self, tokens: List[str]) -> None:
        """
        Create vocabulary from a list of tokens.
        Special tokens are added if not already in the list.

        Args:
            tokens (List[str]): list of tokens
        """
        # note: lookups of unknown tokens return UNK_ID; _stoi itself does not grow

        # special symbols
        self.specials = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]

        # don't allow access to _stoi and _itos outside of this class
        self._stoi: Dict[str, int] = {}  # string to index
        self._itos: List[str] = []  # index to string

        # construct
        self.add_tokens(tokens=self.specials + tokens)
        assert len(self._stoi) == len(self._itos)

        # assign special-token indices after the mappings are built
        self.pad_index = self.lookup(PAD_TOKEN)
        self.bos_index = self.lookup(BOS_TOKEN)
        self.eos_index = self.lookup(EOS_TOKEN)
        self.unk_index = self.lookup(UNK_TOKEN)
        assert self.pad_index == PAD_ID
        assert self.bos_index == BOS_ID
        assert self.eos_index == EOS_ID
        assert self.unk_index == UNK_ID
        assert self._itos[UNK_ID] == UNK_TOKEN

    def add_tokens(self, tokens: List[str]) -> None:
        """
        Add a list of tokens to the vocabulary.

        Args:
            tokens (List[str]): list of tokens to add to the vocabulary
        """
        for t in tokens:
            # add to vocab if not already there
            if t not in self._itos:
                new_index = len(self._itos)
                self._itos.append(t)
                self._stoi[t] = new_index

    def to_file(self, file: Path) -> None:
        """
        Save the vocabulary to a file, writing the token with index i on line i.

        Args:
            file (Path): path to the file where the vocabulary is written
        """
        write_list_to_file(file, self._itos)

    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is covered by the vocabulary.

        Args:
            token (str): token to check
        Returns:
            bool: True if the token is unknown (not covered), False otherwise
        """
        return self.lookup(token) == UNK_ID

    def lookup(self, token: str) -> int:
        """
        Look up a token in the encoding dictionary (needed for multiprocessing).

        Args:
            token (str): surface string
        Returns:
            int: token id, or UNK_ID if the token is not in the vocabulary
        """
        return self._stoi.get(token, UNK_ID)

    def __len__(self) -> int:
        return len(self._itos)

    def __eq__(self, other) -> bool:
        if isinstance(other, Vocabulary):
            return self._itos == other._itos
        return False

    def array_to_sentence(
        self, array: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
    ) -> List[str]:
        """
        Convert an array of IDs to a sentence (list of tokens), optionally cutting
        the result off at the end-of-sequence token.

        Args:
            array (numpy.ndarray): 1D array containing indices
            cut_at_eos (bool): cut the decoded sentence at the first <eos>
            skip_pad (bool): skip generated <pad> tokens

        Returns:
            List[str]: list of strings (tokens)
        """
        sentence = []
        for i in array:
            s = self._itos[i]
            if skip_pad and s == PAD_TOKEN:
                continue
            sentence.append(s)
            # stop right after the eos token itself has been appended
            if cut_at_eos and s == EOS_TOKEN:
                break
        return sentence
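
    # Illustration (added comment): assuming "drug" is in the vocabulary, an input of
    #   [vocab.lookup("drug"), EOS_ID, PAD_ID, PAD_ID]
    # decodes to ["drug", EOS_TOKEN] with the defaults above: pad tokens are skipped
    # and decoding stops right after the first end-of-sequence token.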

    def arrays_to_sentences(
        self, arrays: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
    ) -> List[List[str]]:
        """
        Convert multiple arrays containing sequences of token IDs to their sentences,
        optionally cutting them off at the end-of-sequence token.

        Args:
            arrays (numpy.ndarray): 2D array containing indices
            cut_at_eos (bool): cut the decoded sentences at the first <eos>
            skip_pad (bool): skip generated <pad> tokens
        Returns:
            List[List[str]]: list of lists of strings (tokens)
        """
        return [
            self.array_to_sentence(
                array=array, cut_at_eos=cut_at_eos, skip_pad=skip_pad
            )
            for array in arrays
        ]

    def sentences_to_ids(
        self,
        sentences: List[List[str]],
        padded: bool = False,
        bos: bool = False,
        eos: bool = False,
    ) -> Tuple[List[List[int]], List[int]]:
        """
        Encode sentences to indices and, if requested, pad the sequences to the
        maximum length of the given sentences.

        Args:
            sentences (List[List[str]]): list of tokenized sentences
            padded (bool): pad every sequence to the maximum length
            bos (bool): prepend the <bos> token to every sentence
            eos (bool): append the <eos> token to every sentence

        Returns:
            - encoded ids (padded if `padded` is True)
            - original lengths before padding
        """
        max_len = max(len(sent) for sent in sentences)
        if bos:
            max_len += 1
        if eos:
            max_len += 1
        sentences_enc, lengths = [], []
        for sent in sentences:
            encoded = [self.lookup(s) for s in sent]
            if bos:
                encoded = [self.bos_index] + encoded
            if eos:
                encoded = encoded + [self.eos_index]
            if padded:
                offset = max(0, max_len - len(encoded))
                sentences_enc.append(encoded + [self.pad_index] * offset)
            else:
                sentences_enc.append(encoded)
            lengths.append(len(encoded))
        return sentences_enc, lengths
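
    # Illustration (added comment): with padded=True and eos=True, sentences of
    # lengths 2 and 3 are encoded to max_len 4 (3 tokens + eos); the shorter one is
    # padded with pad_index, and the returned lengths are [3, 4] (eos counted,
    # padding not).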

    def log_vocab(self, k: int) -> str:
        """Return the first k vocabulary entries as a printable string."""
        return " ".join(f"({i}) {t}" for i, t in enumerate(self._itos[:k]))

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(len={len(self)}, "
            f"specials={self.specials})"
        )

    @staticmethod
    def sort_and_cut(
        counter: Counter, max_size: int = sys.maxsize, min_freq: int = -1
    ) -> List[str]:
        """
        Cut the counter down to the most frequent tokens, sorted by frequency
        (descending) and then alphabetically.

        Args:
            counter (Counter): flattened token list in a Counter object
            max_size (int): maximum size of the vocabulary
            min_freq (int): minimum frequency for an item to be included

        Returns:
            List[str]: valid tokens
        """
        # filter counter by min frequency
        if min_freq > -1:
            counter = Counter({t: c for t, c in counter.items() if c >= min_freq})

        # sort alphabetically first, then (stably) by frequency
        tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # cut off
        vocab_tokens = [i[0] for i in tokens_and_frequencies[:max_size]]
        assert len(vocab_tokens) <= max_size, (len(vocab_tokens), max_size)
        return vocab_tokens
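
    # Illustration (added comment): for Counter({"aspirin": 5, "dose": 5, "mg": 1})
    # and min_freq=2, "mg" is filtered out and the tie between "aspirin" and "dose"
    # is broken alphabetically, giving ["aspirin", "dose"].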

    @staticmethod
    def build_vocab(
        cfg: Dict, collection: Optional[RelationCollection] = None
    ) -> "Vocabulary":
        """
        Build a vocabulary either from a vocab file or from a relation collection.

        Args:
            cfg (Dict): data cfg
            collection (Optional[RelationCollection]): collection whose tokens are
                used when no vocab file is given

        Returns:
            Vocabulary: created from either `vocab_file` or the collection's tokens
        """
        vocab_file = cfg.get("voc_file", None)
        min_freq = cfg.get("voc_min_freq", 1)  # min freq for an item to be included
        max_size = int(cfg.get("voc_limit", sys.maxsize))  # max size of vocabulary
        assert max_size > 0

        if vocab_file is not None:
            # load it from file as-is (`sort_and_cut()` is not applied)
            unique_tokens = read_list_from_file(Path(vocab_file))

        elif collection is not None:
            # collect lowercased token texts from the collection
            tokens = []
            for doc in collection.tokens:
                for t in doc:
                    tokens.append(t.text.lower())

            # build the unique token list from the token frequencies
            counter = Counter(tokens)
            unique_tokens = Vocabulary.sort_and_cut(counter, max_size, min_freq)
        else:
            raise ValueError(
                "Please provide a vocab file path or a relation collection."
            )

        vocab = Vocabulary(unique_tokens)
        assert len(vocab) <= max_size + len(vocab.specials), (len(vocab), max_size)

        # check that every special token except UNK is covered by the vocabulary
        for s in vocab.specials:
            assert s == UNK_TOKEN or not vocab.is_unk(s)

        return vocab
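
    # Illustration (added comment): build_vocab only reads the keys "voc_file",
    # "voc_min_freq" and "voc_limit" from cfg, e.g. (values hypothetical):
    #   {"voc_min_freq": 10, "voc_limit": 30000}   # build from a collection
    #   {"voc_file": "path/to/vocab.txt"}          # load a previously saved file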

    @staticmethod
    def create_vocabulary(
        dataset: str,
        train_collection: RelationCollection,
        save_to_disk: bool = True,
    ) -> "Vocabulary":
        """Create the vocabulary of a dataset from its train split.

        Args:
            dataset (str): dataset's name
            train_collection (RelationCollection): train split of the dataset
            save_to_disk (bool): write the vocabulary to the dataset's vocab file

        Returns:
            Vocabulary: vocabulary built from the train split
        """
        # configuration
        cfg = {
            "voc_min_freq": VOC_MIN_FREQ,
        }
        # create vocabulary
        vocabulary = Vocabulary.build_vocab(cfg=cfg, collection=train_collection)
        logger.info(
            "Vocabulary created for %s dataset: %d tokens", dataset, len(vocabulary)
        )

        # save vocab to the same per-dataset file that `load_vocab` reads from
        if save_to_disk:
            vocab_file = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
            vocabulary.to_file(vocab_file)

        return vocabulary

    @staticmethod
    def load_vocab(dataset: str) -> "Vocabulary":
        """Load the vocabulary of a dataset from its vocab file.

        Args:
            dataset (str): dataset's name

        Returns:
            Vocabulary: vocabulary of the dataset
        """
        path = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]

        return Vocabulary(read_list_from_file(path))
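

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original module).
    # The token lists below are made up; everything else uses only this module's API.
    vocab = Vocabulary(["the", "patient", "received", "aspirin"])
    print(vocab)  # e.g. Vocabulary(len=8, specials=[...])

    # encode two tokenized sentences with padding and eos;
    # "ibuprofen" is out of vocabulary and maps to the unk index
    ids, lengths = vocab.sentences_to_ids(
        [["the", "patient"], ["received", "ibuprofen", "aspirin"]],
        padded=True,
        eos=True,
    )
    print(ids, lengths)

    # decode back to tokens: pads are skipped, decoding stops after eos
    print(vocab.arrays_to_sentences(np.array(ids)))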