--- a
+++ b/src/vocabulary.py
@@ -0,0 +1,337 @@
+# coding: utf-8
+"""
+Vocabulary module
+
+Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/vocabulary.py
+"""
+
+# Base Dependencies
+# -----------------
+import sys
+import logging
+import numpy as np
+from collections import Counter
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+
+# Local Dependencies
+# ------------------
+from constants import (
+    BOS_ID,
+    BOS_TOKEN,
+    EOS_ID,
+    EOS_TOKEN,
+    PAD_ID,
+    PAD_TOKEN,
+    UNK_ID,
+    UNK_TOKEN,
+)
+from models.relation_collection import RelationCollection
+from utils import read_list_from_file, write_list_to_file
+
+# Constants
+# ---------
+from constants import DATASETS_PATHS, N2C2_VOCAB_PATH, DDI_VOCAB_PATH
+
+VOC_MIN_FREQ = 10
+
+
+logger = logging.getLogger(__name__)
+
+
+class Vocabulary:
+    """Vocabulary represents a mapping between tokens and indices."""
+
+    def __init__(self, tokens: List[str]) -> None:
+        """
+        Create a vocabulary from a list of tokens.
+        Special tokens are added if not already in the list.
+
+        Args:
+            tokens (List[str]): list of tokens
+        """
+        # note: _stoi is a plain dict; `lookup()` falls back to UNK_ID for
+        # unknown tokens without inserting them
+
+        # special symbols
+        self.specials = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
+
+        # don't allow to access _stoi and _itos outside of this class
+        self._stoi: Dict[str, int] = {}  # string to index
+        self._itos: List[str] = []  # index to string
+
+        # construct
+        self.add_tokens(tokens=self.specials + tokens)
+        assert len(self._stoi) == len(self._itos)
+
+        # assign after stoi is built
+        self.pad_index = self.lookup(PAD_TOKEN)
+        self.bos_index = self.lookup(BOS_TOKEN)
+        self.eos_index = self.lookup(EOS_TOKEN)
+        self.unk_index = self.lookup(UNK_TOKEN)
+        assert self.pad_index == PAD_ID
+        assert self.bos_index == BOS_ID
+        assert self.eos_index == EOS_ID
+        assert self.unk_index == UNK_ID
+        assert self._itos[UNK_ID] == UNK_TOKEN
+
+    def add_tokens(self, tokens: List[str]) -> None:
+        """
+        Add a list of tokens to the vocabulary.
+
+        Args:
+            tokens (List[str]): list of tokens to add to the vocabulary
+        """
+        for t in tokens:
+            new_index = len(self._itos)
+            # add to vocab if not already there
+            if t not in self._itos:
+                self._itos.append(t)
+                self._stoi[t] = new_index
+
+    def to_file(self, file: Path) -> None:
+        """
+        Save the vocabulary to a file, writing the token with index i on line i.
+
+        Args:
+            file (Path): path to the file where the vocabulary is written
+        """
+        write_list_to_file(file, self._itos)
+
+    def is_unk(self, token: str) -> bool:
+        """
+        Check whether a token is covered by the vocabulary.
+
+        Args:
+            token (str): surface form of the token
+        Returns:
+            bool: True if the token is unknown (maps to UNK), False otherwise
+        """
+        return self.lookup(token) == UNK_ID
+
+    def lookup(self, token: str) -> int:
+        """
+        Look up the encoding dictionary (needed for multiprocessing).
+
+        Args:
+            token (str): surface str
+        Returns:
+            int: token id (UNK_ID if the token is not in the vocabulary)
+        """
+        return self._stoi.get(token, UNK_ID)
+
+    def __len__(self) -> int:
+        return len(self._itos)
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, Vocabulary):
+            return self._itos == other._itos
+        return False
+
+    def array_to_sentence(
+        self, array: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
+    ) -> List[str]:
+        """
+        Convert an array of IDs to a sentence, optionally cutting the result off
+        at the end-of-sequence token.
+
+        Args:
+            array (numpy.ndarray): 1D array containing indices
+            cut_at_eos (bool): cut the decoded sentence at the first <eos>
+            skip_pad (bool): skip generated <pad> tokens
+
+        Returns:
+            List[str]: list of strings (tokens)
+        """
+        sentence = []
+        for i in array:
+            s = self._itos[i]
+            if skip_pad and s == PAD_TOKEN:
+                continue
+            sentence.append(s)
+            # break at the position AFTER eos
+            if cut_at_eos and s == EOS_TOKEN:
+                break
+        return sentence
+
+    def arrays_to_sentences(
+        self, arrays: np.ndarray, cut_at_eos: bool = True, skip_pad: bool = True
+    ) -> List[List[str]]:
+        """
+        Convert multiple arrays containing sequences of token IDs to their sentences,
+        optionally cutting them off at the end-of-sequence token.
+
+        Args:
+            arrays (numpy.ndarray): 2D array containing indices
+            cut_at_eos (bool): cut the decoded sentences at the first <eos>
+            skip_pad (bool): skip generated <pad> tokens
+        Returns:
+            List[List[str]]: list of lists of strings (tokens)
+        """
+        return [
+            self.array_to_sentence(
+                array=array, cut_at_eos=cut_at_eos, skip_pad=skip_pad
+            )
+            for array in arrays
+        ]
+
+    def sentences_to_ids(
+        self,
+        sentences: List[List[str]],
+        padded: bool = False,
+        bos: bool = False,
+        eos: bool = False,
+    ) -> Tuple[List[List[int]], List[int]]:
+        """
+        Encode sentences to indices and, if `padded` is True, pad the sequences to
+        the maximum length of the given sentences.
+
+        Args:
+            sentences (List[List[str]]): list of tokenized sentences
+            padded (bool): pad each sequence to the maximum length
+            bos (bool): prepend the <bos> token to each sequence
+            eos (bool): append the <eos> token to each sequence
+
+        Returns:
+            - encoded (and optionally padded) ids
+            - sequence lengths (incl. <bos>/<eos>) before padding
+        """
+        max_len = max([len(sent) for sent in sentences])
+        if bos:
+            max_len += 1
+        if eos:
+            max_len += 1
+        sentences_enc, lengths = [], []
+        for sent in sentences:
+            encoded = [self.lookup(s) for s in sent]
+            if bos:
+                encoded = [self.bos_index] + encoded
+            if eos:
+                encoded = encoded + [self.eos_index]
+            if padded:
+                offset = max(0, max_len - len(encoded))
+                sentences_enc.append(encoded + [self.pad_index] * offset)
+            else:
+                sentences_enc.append(encoded)
+            lengths.append(len(encoded))
+        return sentences_enc, lengths
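+
+    # Illustrative sketch of `sentences_to_ids` (not part of the original module):
+    # with the special-token ids asserted in `__init__` (pad=0, bos=1, eos=2, unk=3),
+    # a vocabulary built from ["hello", "world"] maps "hello" -> 4 and "world" -> 5, so
+    #
+    #     vocab.sentences_to_ids([["hello", "world"], ["hello"]],
+    #                            padded=True, bos=True, eos=True)
+    #
+    # returns ([[1, 4, 5, 2], [1, 4, 2, 0]], [4, 3]): the shorter sentence is padded
+    # with the pad index, and the reported lengths exclude the padding.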
+
+    def log_vocab(self, k: int) -> str:
+        """First k vocabulary entries."""
+        return " ".join(f"({i}) {t}" for i, t in enumerate(self._itos[:k]))
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(len={self.__len__()}, "
+            f"specials={self.specials})"
+        )
+
+    @staticmethod
+    def sort_and_cut(
+        counter: Counter, max_size: int = sys.maxsize, min_freq: int = -1
+    ) -> List[str]:
+        """
+        Cut the counter to the most frequent tokens, sorted by frequency (descending)
+        and alphabetically within ties.
+
+        Args:
+            counter (Counter): flattened token list in a Counter object
+            max_size (int): maximum size of the vocabulary
+            min_freq (int): minimum frequency for an item to be included
+
+        Returns:
+            List[str]: valid tokens
+        """
+        # filter counter by min frequency
+        if min_freq > -1:
+            counter = Counter({t: c for t, c in counter.items() if c >= min_freq})
+
+        # sort by frequency (descending), then alphabetically within ties
+        tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
+        tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
+
+        # cut off
+        vocab_tokens = [i[0] for i in tokens_and_frequencies[:max_size]]
+        assert len(vocab_tokens) <= max_size, (len(vocab_tokens), max_size)
+        return vocab_tokens
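+
+    # Illustrative sketch of `sort_and_cut` (not part of the original module):
+    #
+    #     Vocabulary.sort_and_cut(Counter({"a": 3, "c": 2, "b": 2, "d": 1}),
+    #                             max_size=2, min_freq=2)
+    #
+    # drops "d" (below min_freq), orders the rest as ["a", "b", "c"]
+    # (frequency first, ties alphabetically) and cuts the list to ["a", "b"].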
+
+    @staticmethod
+    def build_vocab(
+        cfg: Dict, collection: Optional[RelationCollection] = None
+    ) -> "Vocabulary":
+        """
+        Builds a vocabulary either from a vocab file or from a relation collection.
+
+        Args:
+            cfg (Dict): data cfg
+            collection (Optional[RelationCollection]): collection whose tokens are
+                used to build the vocabulary if no vocab file is given
+
+        Returns:
+            Vocabulary: created from either the `voc_file` in `cfg` or `collection`
+        """
+        vocab_file = cfg.get("voc_file", None)
+        min_freq = cfg.get("voc_min_freq", 1)  # min freq for an item to be included
+        max_size = int(cfg.get("voc_limit", sys.maxsize))  # max size of vocabulary
+        assert max_size > 0
+
+        if vocab_file is not None:
+            # load it from file (`sort_and_cut()` is not applied)
+            unique_tokens = read_list_from_file(Path(vocab_file))
+
+        elif collection is not None:
+            # tokenize sentences
+            tokens = []
+            for doc in collection.tokens:
+                for t in doc:
+                    tokens.append(t.text.lower())
+
+            # newly create unique token list (language-wise)
+            counter = Counter(tokens)
+            unique_tokens = Vocabulary.sort_and_cut(counter, max_size, min_freq)
+        else:
+            raise ValueError(
+                "Please provide a vocab file path or a relation collection."
+            )
+
+        vocab = Vocabulary(unique_tokens)
+        assert len(vocab) <= max_size + len(vocab.specials), (len(vocab), max_size)
+
+        # check for all specials except the UNK token whether they are OOVs
+        for s in vocab.specials:
+            assert s == UNK_TOKEN or not vocab.is_unk(s)
+
+        return vocab
+
+    @staticmethod
+    def create_vocabulary(
+        dataset: str, train_collection: RelationCollection, save_to_disk: bool = True
+    ) -> "Vocabulary":
+        """Creates the vocabulary of a dataset.
+
+        Args:
+            dataset (str): dataset's name
+            train_collection (RelationCollection): train split of the dataset
+            save_to_disk (bool): write the vocabulary to its vocab file
+
+        Returns:
+            Vocabulary: vocabulary built from the train split
+        """
+        # configuration
+        cfg = {
+            "voc_min_freq": VOC_MIN_FREQ,
+        }
+        # create vocabulary
+        vocabulary = Vocabulary.build_vocab(cfg=cfg, collection=train_collection)
+        logger.info(
+            "Vocabulary created for %s dataset: %d tokens", dataset, len(vocabulary)
+        )
+
+        # save vocab to file (same paths as `load_vocab`)
+        if save_to_disk:
+            vocab_file = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
+            vocabulary.to_file(vocab_file)
+
+        return vocabulary
+
+    @staticmethod
+    def load_vocab(dataset: str) -> "Vocabulary":
+        """Loads the vocabulary of a dataset.
+
+        Args:
+            dataset (str): dataset's name
+
+        Returns:
+            Vocabulary: vocabulary of the dataset
+        """
+        path = {"n2c2": N2C2_VOCAB_PATH, "DDI": DDI_VOCAB_PATH}[dataset]
+
+        return Vocabulary(read_list_from_file(path))
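+
+
+# Usage sketch (illustrative only, not part of the original module): builds a tiny
+# in-memory vocabulary and round-trips two sentences through ids and back. The
+# dataset-level helpers (`create_vocabulary`, `load_vocab`) are not exercised here
+# because they depend on the repository's constants and data files.
+if __name__ == "__main__":
+    toy_vocab = Vocabulary(["hello", "world"])
+    ids, lengths = toy_vocab.sentences_to_ids(
+        [["hello", "world"], ["hello", "unseen"]], padded=True, bos=True, eos=True
+    )
+    print(ids, lengths)  # "unseen" maps to the UNK id
+    print(toy_vocab.arrays_to_sentences(np.array(ids)))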