Diff of /unimol/data/dictionary.py [000000] .. [b40915]

Switch to side-by-side view

--- a
+++ b/unimol/data/dictionary.py
@@ -0,0 +1,157 @@
+# Copyright (c) DP Technology.
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+
+import numpy as np
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+class DecoderDictionary:
+    """A mapping from symbols to consecutive integers"""
+
+    def __init__(
+        self,
+        *,  # begin keyword-only arguments
+        bos="[CLS]",
+        pad="[PAD]",
+        eos="[SEP]",
+        unk="[UNK]",
+        extra_special_symbols=None,
+    ):
+        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+        self.symbols = []
+        self.count = []
+        self.indices = {}
+        self.idx2sym = {}
+        self.specials = set()
+        self.specials.add(bos)
+        self.specials.add(unk)
+        self.specials.add(pad)
+        self.specials.add(eos)
+
+    def __eq__(self, other):
+        return self.indices == other.indices
+
+    def __getitem__(self, idx):
+        if idx < len(self.symbols):
+            return self.symbols[idx]
+        return self.unk_word
+
+    def __len__(self):
+        """Returns the number of symbols in the dictionary"""
+        return len(self.symbols)
+
+    def __contains__(self, sym):
+        return sym in self.indices
+
+    def vec_index(self, a):
+        return np.vectorize(self.index)(a)
+
+    def index(self, sym):
+        """Returns the index of the specified symbol"""
+        assert isinstance(sym, str)
+        if sym in self.indices:
+            return self.indices[sym]
+        return self.indices[self.unk_word]
+    
+    def index2symbol(self, idx):
+        """Returns the corresponding symbol of the specified index"""
+        assert isinstance(idx, int)
+        if idx in self.idx2sym:
+            return self.idx2sym[idx]
+        return self.unk_word
+    
+    def special_index(self):
+        return [self.index(x) for x in self.specials]
+
+    def add_symbol(self, word, n=1, overwrite=False, is_special=False):
+        """Adds a word to the dictionary"""
+        if is_special:
+            self.specials.add(word)
+        if word in self.indices and not overwrite:
+            idx = self.indices[word]
+            self.count[idx] = self.count[idx] + n
+            return idx
+        else:
+            idx = len(self.symbols)
+            self.indices[word] = idx
+            self.idx2sym[idx] = word
+            self.symbols.append(word)
+            self.count.append(n)
+            return idx
+
+    def bos(self):
+        """Helper to get index of beginning-of-sentence symbol"""
+        return self.index(self.bos_word)
+
+    def pad(self):
+        """Helper to get index of pad symbol"""
+        return self.index(self.pad_word)
+
+    def eos(self):
+        """Helper to get index of end-of-sentence symbol"""
+        return self.index(self.eos_word)
+
+    def unk(self):
+        """Helper to get index of unk symbol"""
+        return self.index(self.unk_word)
+
+    @classmethod
+    def load(cls, f):
+        """Loads the dictionary from a text file with the format:
+
+        ```
+        <symbol0> <count0>
+        <symbol1> <count1>
+        ...
+        ```
+        """
+        d = cls()
+        d.add_from_file(f)
+        return d
+
+    def add_from_file(self, f):
+        """
+        Loads a pre-existing dictionary from a text file and adds its symbols
+        to this instance.
+        """
+        if isinstance(f, str):
+            try:
+                with open(f, "r", encoding="utf-8") as fd:
+                    self.add_from_file(fd)
+            except FileNotFoundError as fnfe:
+                raise fnfe
+            except UnicodeError:
+                raise Exception(
+                    "Incorrect encoding detected in {}, please "
+                    "rebuild the dataset".format(f)
+                )
+            return
+
+        lines = f.readlines()
+
+        for line_idx, line in enumerate(lines):
+            try:
+                splits = line.rstrip().rsplit(" ", 1)
+                line = splits[0]
+                field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
+                if field == "#overwrite":
+                    overwrite = True
+                    line, field = line.rsplit(" ", 1)
+                else:
+                    overwrite = False
+                count = int(field)
+                word = line
+                if word in self and not overwrite:
+                    logger.info(
+                        "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
+                    )
+                else:
+                    self.add_symbol(word, n=count, overwrite=overwrite)
+            except ValueError:
+                raise ValueError(
+                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
+                )