Diff of /src/vocabulary.py [000000] .. [735bb5]

--- a
+++ b/src/vocabulary.py
@@ -0,0 +1,337 @@
+# coding: utf-8
+"""
+Vocabulary module
+
+Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/vocabulary.py
+"""
+
+# Base Dependencies
+# -----------------
+import sys
+import logging
+import numpy as np
+from collections import Counter
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+
+# Local Dependencies
+# ------------------
+from constants import (
+    BOS_ID,
+    BOS_TOKEN,
+    EOS_ID,
+    EOS_TOKEN,
+    PAD_ID,
+    PAD_TOKEN,
+    UNK_ID,
+    UNK_TOKEN,
+)
+from models.relation_collection import RelationCollection
+from utils import read_list_from_file, write_list_to_file
+
+# Constants
+# ---------
+from constants import DATASETS_PATHS, N2C2_VOCAB_PATH, DDI_VOCAB_PATH
+
+VOC_MIN_FREQ = 10
+
+
+logger = logging.getLogger(__name__)
+
+
+class Vocabulary:
+    """Vocabulary represents mapping between tokens and indices."""
+
+    def __init__(self, tokens: List[str]) -> None:
+        """
+        Create vocabulary from list of tokens.
+        Special tokens are added if not already in list.
+
+        Args:
+            tokens (List[str]): list of tokens
+        """
+        # note: `_stoi` never grows on lookups; unknown tokens fall back to UNK_ID
+
+        # special symbols
+        self.specials = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
+
+        # don't allow to access _stoi and _itos outside of this class
+        self._stoi: Dict[str, int] = {}  # string to index
+        self._itos: List[str] = []  # index to string
+
+        # construct
+        self.add_tokens(tokens=self.specials + tokens)
+        assert len(self._stoi) == len(self._itos)
+
+        # assign after stoi is built
+        self.pad_index = self.lookup(PAD_TOKEN)
+        self.bos_index = self.lookup(BOS_TOKEN)
+        self.eos_index = self.lookup(EOS_TOKEN)
+        self.unk_index = self.lookup(UNK_TOKEN)
+        assert self.pad_index == PAD_ID
+        assert self.bos_index == BOS_ID
+        assert self.eos_index == EOS_ID
+        assert self.unk_index == UNK_ID
+        assert self._itos[UNK_ID] == UNK_TOKEN
+
+    def add_tokens(self, tokens: List[str]) -> None:
+        """
+        Add list of tokens to vocabulary
+
+        Args:
+            tokens (List[str]): list of tokens to add to the vocabulary
+        """
+        for t in tokens:
+            new_index = len(self._itos)
+            # add to vocab if not already there
+            if t not in self._itos:
+                self._itos.append(t)
+                self._stoi[t] = new_index
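+
+    # Illustrative sketch of `add_tokens` (hypothetical tokens): indices follow
+    # insertion order and duplicates are ignored.
+    #
+    #     v = Vocabulary([])                      # starts with the 4 special symbols
+    #     v.add_tokens(["drug", "dose", "drug"])  # the duplicate "drug" is skipped
+    #     len(v)                                  # -> 6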
+
+    def to_file(self, file: Path) -> None:
+        """
+        Save the vocabulary to a file, writing the token with index i on line i.
+
+        Args:
+            file (Path): path to file where the vocabulary is written
+        """
+        write_list_to_file(file, self._itos)
+
+    def is_unk(self, token: str) -> bool:
+        """
+        Check whether a token is outside the vocabulary (i.e. maps to <unk>).
+
+        Args:
+            token (str): token to check
+        Returns:
+            bool: True if the token is NOT covered by the vocabulary, False otherwise
+        """
+        return self.lookup(token) == UNK_ID
+
+    def lookup(self, token: str) -> int:
+        """
+        Look up the encoding dictionary (needed for multiprocessing).
+
+        Args:
+            token (str): surface string
+        Returns:
+            int: token id (UNK_ID if the token is not in the vocabulary)
+        """
+        return self._stoi.get(token, UNK_ID)
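+
+    # Illustrative sketch of `lookup`/`is_unk` (hypothetical tokens; the special
+    # symbols occupy ids 0-3 as asserted in __init__):
+    #
+    #     v = Vocabulary(["patient"])
+    #     v.lookup("patient")   # -> 4 (first id after the special symbols)
+    #     v.lookup("aspirin")   # -> UNK_ID, hence v.is_unk("aspirin") is True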
+
+    def __len__(self) -> int:
+        return len(self._itos)
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, Vocabulary):
+            return self._itos == other._itos
+        return False
+
+    def array_to_sentence(
+        self, array: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
+    ) -> List[str]:
+        """
+        Converts an array of IDs to a sentence, optionally cutting the result off at the
+        end-of-sequence token.
+
+        Args:
+            array (numpy.ndarray): 1D array containing indices
+            cut_at_eos (bool): cut the decoded sentences at the first <eos>
+            skip_pad (bool): skip generated <pad> tokens
+
+        Returns:
+            List[str]: list of strings (tokens)
+        """
+        sentence = []
+        for i in array:
+            s = self._itos[i]
+            if skip_pad and s == PAD_TOKEN:
+                continue
+            sentence.append(s)
+            # break at the position AFTER eos
+            if cut_at_eos and s == EOS_TOKEN:
+                break
+        return sentence
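+
+    # Illustrative sketch of `array_to_sentence` (hypothetical tokens; ids 4 and 5
+    # are the first two non-special entries):
+    #
+    #     v = Vocabulary(["hello", "world"])
+    #     v.array_to_sentence(np.array([4, 5, EOS_ID, PAD_ID]))
+    #     # -> ["hello", "world", EOS_TOKEN]: padding is skipped and the
+    #     #    sentence is cut right after the first <eos>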
+
+    def arrays_to_sentences(
+        self, arrays: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
+    ) -> List[List[str]]:
+        """
+        Convert multiple arrays containing sequences of token IDs to their sentences,
+        optionally cutting them off at the end-of-sequence token.
+
+        Args:
+            arrays (numpy.ndarray): 2D array containing indices
+            cut_at_eos (bool): cut the decoded sentences at the first <eos>
+            skip_pad (bool): skip generated <pad> tokens
+        Returns:
+            List[List[str]]: list of list of strings (tokens)
+        """
+        return [
+            self.array_to_sentence(
+                array=array, cut_at_eos=cut_at_eos, skip_pad=skip_pad
+            )
+            for array in arrays
+        ]
+
+    def sentences_to_ids(
+        self,
+        sentences: List[List[str]],
+        padded: bool = False,
+        bos: bool = False,
+        eos: bool = False,
+    ) -> Tuple[List[List[int]], List[int]]:
+        """
+        Encode sentences to indices and, if requested, pad every sequence to the
+        length of the longest sentence (plus BOS/EOS if enabled).
+
+        Args:
+            sentences (List[List[str]]): list of tokenized sentences
+            padded (bool): pad the encoded sentences with `pad_index`
+            bos (bool): prepend the beginning-of-sequence token
+            eos (bool): append the end-of-sequence token
+
+        Returns:
+            - (possibly padded) token ids
+            - original lengths before padding
+        """
+        max_len = max([len(sent) for sent in sentences])
+        if bos:
+            max_len += 1
+        if eos:
+            max_len += 1
+        sentences_enc, lengths = [], []
+        for sent in sentences:
+            encoded = [self.lookup(s) for s in sent]
+            if bos:
+                encoded = [self.bos_index] + encoded
+            if eos:
+                encoded = encoded + [self.eos_index]
+            if padded:
+                offset = max(0, max_len - len(encoded))
+                sentences_enc.append(encoded + [self.pad_index] * offset)
+            else:
+                sentences_enc.append(encoded)
+            lengths.append(len(encoded))
+        return sentences_enc, lengths
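+
+    # Illustrative sketch of `sentences_to_ids` (hypothetical tokens; pad/bos/eos ids
+    # follow the asserts in __init__):
+    #
+    #     v = Vocabulary(["hello", "world"])
+    #     ids, lengths = v.sentences_to_ids(
+    #         [["hello", "world"], ["hello"]], padded=True, bos=True, eos=True
+    #     )
+    #     # ids     -> [[1, 4, 5, 2], [1, 4, 2, 0]]  (BOS, tokens, EOS, then padding)
+    #     # lengths -> [4, 3]                        (lengths exclude the padding)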
+
+    def log_vocab(self, k: int) -> str:
+        """first k vocab entities"""
+        return " ".join(f"({i}) {t}" for i, t in enumerate(self._itos[:k]))
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(len={self.__len__()}, "
+            f"specials={self.specials})"
+        )
+
+    @staticmethod
+    def sort_and_cut(
+        counter: Counter, max_size: int = sys.maxsize, min_freq: int = -1
+    ) -> List[str]:
+        """
+        Cut the counter to the most frequent tokens, sorted by frequency (descending)
+        and alphabetically within ties.
+
+        Args:
+            counter (Counter): flattened token list in Counter object
+            max_size (int): maximum size of vocabulary
+            min_freq (int): minimum frequency for an item to be included
+
+        Returns:
+            List[str]: valid tokens
+        """
+        # filter counter by min frequency
+        if min_freq > -1:
+            counter = Counter({t: c for t, c in counter.items() if c >= min_freq})
+
+        # sort by frequency, then alphabetically
+        tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
+        tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
+
+        # cut off
+        vocab_tokens = [i[0] for i in tokens_and_frequencies[:max_size]]
+        assert len(vocab_tokens) <= max_size, (len(vocab_tokens), max_size)
+        return vocab_tokens
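+
+    # Illustrative sketch of `sort_and_cut` (hypothetical counts):
+    #
+    #     c = Counter({"the": 50, "dose": 5, "mg": 5, "rare": 1})
+    #     Vocabulary.sort_and_cut(c, max_size=3, min_freq=2)
+    #     # -> ["the", "dose", "mg"]: "rare" is dropped by min_freq, the tie between
+    #     #    "dose" and "mg" is broken alphabetically, and the list is cut to max_size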
+
+    @staticmethod
+    def build_vocab(
+        cfg: Dict, collection: Optional[RelationCollection] = None
+    ) -> "Vocabulary":
+        """
+        Builds a vocabulary either from a vocab file or from a relation collection.
+
+        Args:
+            cfg (Dict): data cfg ("voc_file", "voc_min_freq", "voc_limit")
+            collection (Optional[RelationCollection]): collection whose tokens are used
+                when no vocab file is given
+
+        Returns:
+            Vocabulary: created from either `vocab_file` or the collection's tokens
+        """
+        vocab_file = cfg.get("voc_file", None)
+        min_freq = cfg.get("voc_min_freq", 1)  # min freq for an item to be included
+        max_size = int(cfg.get("voc_limit", sys.maxsize))  # max size of vocabulary
+        assert max_size > 0
+
+        if vocab_file is not None:
+            # load it from file (`sort_and_cut()` is not applied)
+            unique_tokens = read_list_from_file(Path(vocab_file))
+
+        elif collection is not None:
+            # tokenize sentences
+            tokens = []
+            for doc in collection.tokens:
+                for t in doc:
+                    tokens.append(t.text.lower())
+
+            # build the unique token list from the collection (language-wise)
+            counter = Counter(tokens)
+            unique_tokens = Vocabulary.sort_and_cut(counter, max_size, min_freq)
+        else:
+            raise ValueError(
+                "Please provide a vocab file path or a relation collection."
+            )
+
+        vocab = Vocabulary(unique_tokens)
+        assert len(vocab) <= max_size + len(vocab.specials), (len(vocab), max_size)
+
+        # check for all except for UNK token whether they are OOVs
+        for s in vocab.specials:
+            assert s == UNK_TOKEN or not vocab.is_unk(s)
+
+        return vocab
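+
+    # Illustrative sketch of `build_vocab` (hypothetical file path and collection):
+    #
+    #     vocab = Vocabulary.build_vocab({"voc_file": "data/vocab.txt"})
+    #     vocab = Vocabulary.build_vocab(
+    #         {"voc_min_freq": 2, "voc_limit": 30000}, collection=train_collection
+    #     )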
+
+    @staticmethod
+    def create_vocabulary(
+        dataset: str,
+        train_collection: RelationCollection,
+        save_to_disk: bool = True,
+    ) -> "Vocabulary":
+        """Creates the vocabulary of a dataset.
+
+        Args:
+            dataset (str): dataset's name
+            train_collection (RelationCollection): train split of the dataset
+            save_to_disk (bool): if True, write the vocabulary to the dataset's vocab file
+
+        Returns:
+            Vocabulary: vocabulary built from the train split
+        """
+        # configuration
+        cfg = {
+            "voc_min_freq": VOC_MIN_FREQ,
+        }
+        # create vocabulary
+        vocabulary = Vocabulary.build_vocab(cfg=cfg, collection=train_collection)
+        logger.info(
+            "Vocabulary created for the %s dataset: %d tokens", dataset, len(vocabulary)
+        )
+
+        # save vocab to file (same location `load_vocab` reads from)
+        if save_to_disk:
+            vocab_file = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
+            vocabulary.to_file(vocab_file)
+
+        return vocabulary
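+
+    # Illustrative sketch of `create_vocabulary` (hypothetical collection):
+    #
+    #     vocab = Vocabulary.create_vocabulary("n2c2", train_collection, save_to_disk=False)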
+
+    @staticmethod
+    def load_vocab(dataset: str) -> "Vocabulary":
+        """Loads the vocabulary of a dataset
+
+        Args:
+            dataset (str): dataset's name
+
+        Returns:
+            Vocabulary: vocabulary of the dataset
+        """
+        path = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
+
+        return Vocabulary(read_list_from_file(path))
\ No newline at end of file