--- a
+++ b/src/ml_models/bilstm/embeddings.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+"""
+Embedding module
+
+"""
+
+# Base Dependencies
+# -----------------
+import math
+import logging
+from pathlib import Path
+from typing import Dict, Optional
+
+# Local Dependencies
+# ------------------
+from utils import freeze_params
+from vocabulary import Vocabulary
+
+# PyTorch Dependencies
+# --------------------
+import torch
+from torch import Tensor, nn
+
+
+logger = logging.getLogger(__name__)
+
+
+# Embeddings Class
+# ----------------
+class Embeddings(nn.Module):
+    """
+    Simple embeddings class
+
+    Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/embeddings.py
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int = 64,
+        scale: bool = False,
+        vocab_size: int = 0,
+        padding_idx: Optional[int] = 0,
+        freeze: bool = False,
+        **kwargs,
+    ):
+        """Creates a new embedding for the vocabulary
+
+        Args:
+            embedding_dim (int, optional): the embedding dimension. Defaults to 64.
+            scale (bool, optional): if True, embeddings are scaled by the square root of their dimension. Defaults to False.
+            vocab_size (int, optional): size of the vocabulary, i.e., the input dimension. Defaults to 0.
+            padding_idx (int, optional): index used for padding. Defaults to 0.
+            freeze (bool, optional): if True, the embedding weights are frozen and not updated during training. Defaults to False.
+        """
+        # pylint: disable=unused-argument
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.scale = scale
+        self.vocab_size = vocab_size
+        self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)
+
+        if freeze:
+            freeze_params(self)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Perform lookup for input `x` in the embedding table.
+
+        Args:
+            x (Tensor): indices in the vocabulary
+        Returns:
+            embedded representation for `x`
+        """
+
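+        # Optionally scale by sqrt(embedding_dim), following the Transformer convention (Vaswani et al., 2017).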
+        if self.scale:
+            return self.lut(x) * math.sqrt(self.embedding_dim)
+        return self.lut(x)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"embedding_dim={self.embedding_dim}, "
+            f"vocab_size={self.vocab_size})"
+        )
+
+    # from fairseq
+    def load_from_file(self, embed_path: Path, vocab: Vocabulary) -> None:
+        """Loads pretrained embedding weights from text file.
+        - First line is expected to contain vocabulary size and dimension.
+          The dimension has to match the model's specified embedding size,
+          the vocabulary size is used in logging only.
+        - Each line should contain word and embedding weights
+          separated by spaces.
+        - Pretrained vocabulary items that are not part of the
+          model's vocabulary are ignored (not loaded from the file).
+        - The initialization (specified in config["model"]["embed_initializer"])
+          of vocabulary items that are not part of the
+          pretrained vocabulary is kept (not overwritten in this function).
+        - This function should be called after initialization!
+        Example:
+            2 5
+            the -0.0230 -0.0264  0.0287  0.0171  0.1403
+            at -0.0395 -0.1286  0.0275  0.0254 -0.0932
+
+        Args:
+            embed_path (Path): embedding weights text file
+            vocab (Vocabulary): Vocabulary object
+        """
+        # pylint: disable=logging-too-many-args
+        unk_in = False
+        bos_in = False
+        eos_in = False
+        embed_dict: Dict[int, Tensor] = {}
+        # parse file
+        with embed_path.open("r", encoding="utf-8", errors="ignore") as f_embed:
+            vocab_size, d = map(int, f_embed.readline().split())
+            assert self.embedding_dim == d, "Embedding dimension doesn't match."
+            for line in f_embed.readlines():
+                tokens = line.rstrip().split(" ")
+                if tokens[0] in vocab.specials or not vocab.is_unk(tokens[0]):
+                    if vocab.lookup(tokens[0]) == vocab.unk_index:
+                        unk_in = True
+                    # elif vocab.lookup(tokens[0]) == vocab.bos_index:
+                    #     bos_in = True
+                    # elif vocab.lookup(tokens[0]) == vocab.eos_index:
+                    #     eos_in = True
+
+                    embed_dict[vocab.lookup(tokens[0])] = torch.FloatTensor(
+                        [float(t) for t in tokens[1:]]
+                    )
+
+            logger.warning(
+                "Found pretrained vectors for {} of {} ({:.2%}) vocabulary tokens.".format(
+                    len(embed_dict),
+                    len(vocab),
+                    len(embed_dict) / len(vocab),
+                )
+            )
+
+        # assign
+        for idx, weights in embed_dict.items():
+            if idx < self.vocab_size:
+                assert self.embedding_dim == len(weights)
+                self.lut.weight.data[idx] = weights
+
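+        # If the unknown token (unk_index) had no pretrained vector, initialize it with the mean of all embedding rows.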
+        if not unk_in:
+            self.lut.weight.data[vocab.unk_index] = torch.mean(
+                self.lut.weight.data, dim=0
+            )
+
+        logger.warning(
+            "Loaded pretrained embeddings for {} of {} ({:.2%}) vocabulary tokens.".format(
+                len(embed_dict),
+                len(vocab),
+                len(embed_dict) / len(vocab),
+            )
+        )
+
+
+# RDEmbeddings Class
+# ------------------
+class RDEmbeddings(Embeddings):
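+    """Relative distance embeddings.
+
+    Reuses the Embeddings lookup table for (clamped) relative distance values,
+    with a vocabulary of size 2 * input_dim + 1.
+    """
+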
+    def __init__(
+        self,
+        input_dim: int = 0,
+        embedding_dim: int = 64,
+        scale: bool = False,
+        freeze: bool = False,
+        **kwargs,
+    ):
+        """Relative Distance Embedding
+
+        Args:
+            input_dim (int, optional): the maximum absolute relative distance. Defaults to 0.
+            embedding_dim (int, optional): the embedding dimension. Defaults to 64.
+            scale (bool, optional): if True, embeddings are scaled by the square root of their dimension. Defaults to False.
+            freeze (bool, optional): if True, the embedding weights are frozen and not updated during training. Defaults to False.
+        """
+        self.input_dim = input_dim
+
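+        # Distances in [-input_dim, input_dim] are shifted to indices in [0, 2 * input_dim],
+        # so the lookup table needs 2 * input_dim + 1 rows.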
+        super().__init__(
+            embedding_dim=embedding_dim,
+            scale=scale,
+            vocab_size=(self.input_dim * 2 + 1),
+            padding_idx=None,
+            freeze=freeze,
+            **kwargs,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Perform lookup for input `x` in the embedding table.
+
+        Args:
+            x (Tensor): relative distance values
+        Returns:
+            embedded representation for `x`
+        """
+
+        # clamp relative distances to [-input_dim, input_dim] and shift them to non-negative indices in [0, 2 * input_dim]
+        x = torch.clamp(x, min=-self.input_dim, max=self.input_dim) + self.input_dim
+
+        return super().forward(x)
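+
+
+# Usage sketch (illustrative only; the vocabulary size, padding index and tensor values below
+# are made-up example numbers, not values taken from the project configuration):
+#
+#     emb = Embeddings(embedding_dim=64, scale=True, vocab_size=100, padding_idx=0)
+#     token_ids = torch.tensor([[2, 5, 7, 0]])        # (batch, seq_len)
+#     token_vecs = emb(token_ids)                     # (batch, seq_len, 64)
+#
+#     rd_emb = RDEmbeddings(input_dim=50, embedding_dim=32)
+#     distances = torch.tensor([[-3, -1, 0, 2]])      # signed relative distances
+#     dist_vecs = rd_emb(distances)                   # (batch, seq_len, 32)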