unimol/data/dictionary.py
|
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import numpy as np

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class DecoderDictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="[CLS]",
        pad="[PAD]",
        eos="[SEP]",
        unk="[UNK]",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.idx2sym = {}
        self.specials = set()
        self.specials.add(bos)
        self.specials.add(unk)
        self.specials.add(pad)
        self.specials.add(eos)
        # Note: only the special *words* are recorded here; they receive
        # indices once they are added via add_symbol() or loaded from a file.

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def vec_index(self, a):
        """Element-wise `index` over an array-like of symbols"""
        return np.vectorize(self.index)(a)

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.indices[self.unk_word]

    def index2symbol(self, idx):
        """Returns the corresponding symbol of the specified index"""
        assert isinstance(idx, int)
        if idx in self.idx2sym:
            return self.idx2sym[idx]
        return self.unk_word

    def special_index(self):
        """Returns the indices of all special symbols"""
        return [self.index(x) for x in self.specials]

    def add_symbol(self, word, n=1, overwrite=False, is_special=False):
        """Adds a word to the dictionary"""
        if is_special:
            self.specials.add(word)
        if word in self.indices and not overwrite:
            # Existing word: just accumulate its count.
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            # New word (or overwrite requested): append a fresh entry.
            idx = len(self.symbols)
            self.indices[word] = idx
            self.idx2sym[idx] = word
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.index(self.bos_word)

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.index(self.pad_word)

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.index(self.eos_word)

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.index(self.unk_word)

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
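        # Expected line format, as parsed below: "<symbol> <count>", with an
        # optional trailing "#overwrite" flag (example: "C 9999 #overwrite",
        # values hypothetical) that re-adds the symbol even if it already
        # exists in the dictionary.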
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()

        for line_idx, line in enumerate(lines):
            try:
                splits = line.rstrip().rsplit(" ", 1)
                line = splits[0]
                # If no count field is present, fall back to a positional
                # count (earlier lines get larger values).
                field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
                if field == "#overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    logger.info(
                        "Duplicate word found when loading Dictionary: '{}', "
                        "index is {}.".format(word, self.indices[word])
                    )
                else:
                    self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )
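

# Illustrative usage sketch (not part of the original module). The symbols
# added below ("C", "N") and the token list are made-up examples, not values
# taken from any real Uni-Mol dictionary file.
if __name__ == "__main__":
    dictionary = DecoderDictionary()
    # __init__ only records the special words, so register them explicitly
    # before looking anything up (normally they would come from a dictionary
    # file via DecoderDictionary.load(path)).
    for sym in (
        dictionary.bos_word,
        dictionary.pad_word,
        dictionary.eos_word,
        dictionary.unk_word,
    ):
        dictionary.add_symbol(sym, is_special=True)
    dictionary.add_symbol("C")
    dictionary.add_symbol("N")
    print(len(dictionary))                        # 6 symbols in total
    print(dictionary.vec_index(["C", "N", "X"]))  # unknown "X" maps to unk()
    print(dictionary.special_index())             # indices of the special symbols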