Diff of /unimol/data/dictionary.py [000000] .. [b40915]

Switch to unified view

a b/unimol/data/dictionary.py
1
# Copyright (c) DP Technology.
2
# Copyright (c) Facebook, Inc. and its affiliates.
3
#
4
# This source code is licensed under the MIT license found in the
5
# LICENSE file in the root directory of this source tree.
6
import logging
7
8
import numpy as np
9
10
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
11
12
class DecoderDictionary:
    """A mapping from symbols to consecutive integers.

    Symbols are assigned indices in insertion order, starting at 0. The
    constructor only records the *names* of the special symbols (bos, pad,
    eos, unk); it does not insert them. They become resolvable once they are
    added through :meth:`add_symbol` or :meth:`add_from_file`.

    Attributes are documented where they are initialised in ``__init__``.
    """

    # Instances are mutable and compared by content, so they are unhashable.
    # (Defining ``__eq__`` already implies this; it is spelled out for clarity.)
    __hash__ = None

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="[CLS]",
        pad="[PAD]",
        eos="[SEP]",
        unk="[UNK]",
        extra_special_symbols=None,
    ):
        """Create an empty dictionary and remember the special symbol names.

        Args:
            bos: Text of the beginning-of-sentence symbol.
            pad: Text of the padding symbol.
            eos: Text of the end-of-sentence symbol.
            unk: Text of the unknown symbol, used as the fallback for lookups.
            extra_special_symbols: Accepted for interface compatibility; this
                implementation does not use it.
        """
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []  # index -> symbol, in insertion order
        self.count = []  # index -> accumulated count, parallel to ``symbols``
        self.indices = {}  # symbol -> most recently assigned index
        self.idx2sym = {}  # index -> symbol (reverse lookup)
        # Names of the special symbols; may name symbols not (yet) inserted.
        self.specials = {bos, unk, pad, eos}

    def __eq__(self, other):
        """Two dictionaries are equal when their symbol -> index maps match."""
        if not isinstance(other, DecoderDictionary):
            # Let Python try the reflected comparison / identity fallback
            # instead of failing with AttributeError on ``other.indices``.
            return NotImplemented
        return self.indices == other.indices

    def __getitem__(self, idx):
        """Return the symbol at position ``idx``, or the unk word if out of range.

        Negative positions that are valid for the symbol list keep regular
        Python indexing semantics (``d[-1]`` is the last symbol).
        """
        size = len(self.symbols)
        if -size <= idx < size:
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        """Return True if ``sym`` has been added to the dictionary."""
        return sym in self.indices

    def vec_index(self, a):
        """Map an array-like of symbols to an array of indices, element-wise.

        ``otypes`` is given explicitly so that empty inputs are supported
        (``np.vectorize`` cannot infer the output dtype from zero elements)
        and so that ``index`` is not called an extra time on the first
        element for dtype inference. ``int`` is the dtype NumPy would infer
        from the Python ints returned by :meth:`index`.
        """
        return np.vectorize(self.index, otypes=[int])(a)

    def index(self, sym):
        """Returns the index of the specified symbol.

        Unknown symbols resolve to the index of the unk word; a ``KeyError``
        is raised if the unk word itself has not been added.
        """
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.indices[self.unk_word]

    def index2symbol(self, idx):
        """Returns the corresponding symbol of the specified index.

        Accepts Python ints as well as NumPy integer scalars (e.g. elements
        of an index array). Indices that were never assigned resolve to the
        unk word.
        """
        assert isinstance(idx, (int, np.integer))
        # Normalise NumPy scalars to a plain int before the dict lookup.
        return self.idx2sym.get(int(idx), self.unk_word)

    def special_index(self):
        """Return the indices of all special symbols, in ascending order.

        Each special symbol is resolved through :meth:`index`, so specials
        that are not in the dictionary contribute the unk index. The result
        is sorted because ``self.specials`` is a set, whose iteration order
        is not stable across interpreter runs.
        """
        return sorted(self.index(x) for x in self.specials)

    def add_symbol(self, word, n=1, overwrite=False, is_special=False):
        """Adds a word to the dictionary.

        Args:
            word: Symbol to add.
            n: Count to record (or to add to the existing count).
            overwrite: If True, always append a new entry and point ``word``
                at it, even when ``word`` is already present. The earlier
                entry keeps its slot in ``symbols``/``count``/``idx2sym``.
            is_special: If True, also register ``word`` as a special symbol.

        Returns:
            The index now associated with ``word``.
        """
        if is_special:
            self.specials.add(word)
        if word in self.indices and not overwrite:
            # Known word: only accumulate its count.
            idx = self.indices[word]
            self.count[idx] += n
            return idx
        idx = len(self.symbols)
        self.indices[word] = idx
        self.idx2sym[idx] = word
        self.symbols.append(word)
        self.count.append(n)
        return idx

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.index(self.bos_word)

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.index(self.pad_word)

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.index(self.eos_word)

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.index(self.unk_word)

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        ``f`` may be a path or an open text file object; see
        :meth:`add_from_file` for the details of the line format.
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.

        Each line is ``<symbol> <count>``, optionally followed by the flag
        ``#overwrite``. If a line has no space-separated second field, the
        count defaults to the number of lines remaining (including the
        current one). Symbols that are already present and not flagged
        ``#overwrite`` are skipped with an info-level log message.

        Args:
            f: Path to a UTF-8 text file, or an object with ``readlines()``.

        Raises:
            FileNotFoundError: If ``f`` is a path that does not exist.
            Exception: If the file at path ``f`` is not valid UTF-8.
            ValueError: If a line does not follow the expected format.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        total = len(lines)

        for line_idx, raw in enumerate(lines):
            try:
                # Split off the last space-separated field (count or flag).
                # NOTE(review): a blank line yields the empty string as a
                # symbol; kept as-is because skipping it would change the
                # vocabulary size for existing dictionary files.
                parts = raw.rstrip().rsplit(" ", 1)
                word = parts[0]
                field = parts[1] if len(parts) > 1 else str(total - line_idx)
                overwrite = field == "#overwrite"
                if overwrite:
                    # The count is the field preceding the flag; a missing
                    # count makes this unpacking raise ValueError.
                    word, field = word.rsplit(" ", 1)
                count = int(field)
                if word in self and not overwrite:
                    logger.info(
                        "Duplicate word found when loading Dictionary: '%s', index is %s.",
                        word,
                        self.indices[word],
                    )
                else:
                    self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError as err:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                ) from err