|
a |
|
b/unimol/tasks/unimol.py |
|
|
1 |
# Copyright (c) DP Technology. |
|
|
2 |
# This source code is licensed under the MIT license found in the |
|
|
3 |
# LICENSE file in the root directory of this source tree. |
|
|
4 |
|
|
|
5 |
import logging |
|
|
6 |
import os |
|
|
7 |
|
|
|
8 |
import numpy as np |
|
|
9 |
from unicore.data import ( |
|
|
10 |
Dictionary, |
|
|
11 |
NestedDictionaryDataset, |
|
|
12 |
AppendTokenDataset, |
|
|
13 |
PrependTokenDataset, |
|
|
14 |
RightPadDataset, |
|
|
15 |
EpochShuffleDataset, |
|
|
16 |
TokenizeDataset, |
|
|
17 |
RightPadDataset2D, |
|
|
18 |
FromNumpyDataset, |
|
|
19 |
RawArrayDataset, |
|
|
20 |
) |
|
|
21 |
from unimol.data import ( |
|
|
22 |
KeyDataset, |
|
|
23 |
ConformerSampleDataset, |
|
|
24 |
DistanceDataset, |
|
|
25 |
EdgeTypeDataset, |
|
|
26 |
MaskPointsDataset, |
|
|
27 |
RemoveHydrogenDataset, |
|
|
28 |
AtomTypeDataset, |
|
|
29 |
NormalizeDataset, |
|
|
30 |
CroppingDataset, |
|
|
31 |
RightPadDatasetCoord, |
|
|
32 |
Add2DConformerDataset, |
|
|
33 |
LMDBDataset, |
|
|
34 |
) |
|
|
35 |
from unicore.tasks import UnicoreTask, register_task |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
logger = logging.getLogger(__name__) |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
@register_task("unimol")
class UniMolTask(UnicoreTask):
    """Task for training transformer auto-encoder models.

    Pretraining task: molecules are read from per-split LMDB files, atom
    tokens are masked and coordinates perturbed with noise, and the model
    receives padded token / coordinate / pairwise-distance / edge-type
    inputs together with the corresponding reconstruction targets.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--mask-prob",
            default=0.15,
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            default=0.05,
            type=float,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            default=0.05,
            type=float,
            help="probability of replacing a token with a random token",
        )
        parser.add_argument(
            "--noise-type",
            default="uniform",
            choices=["trunc_normal", "uniform", "normal", "none"],
            help="noise type in coordinate noise",
        )
        parser.add_argument(
            "--noise",
            default=1.0,
            type=float,
            help="coordinate noise for masked atoms",
        )
        parser.add_argument(
            "--remove-hydrogen",
            action="store_true",
            help="remove hydrogen atoms",
        )
        parser.add_argument(
            "--remove-polar-hydrogen",
            action="store_true",
            help="remove polar hydrogen atoms",
        )
        parser.add_argument(
            "--max-atoms",
            type=int,
            default=256,
            help="selected maximum number of atoms in a molecule",
        )
        parser.add_argument(
            "--dict-name",
            default="dict.txt",
            help="dictionary file",
        )
        parser.add_argument(
            "--only-polar",
            default=1,
            type=int,
            help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # add mask token
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
        # --only-polar overrides the two hydrogen-removal flags:
        #   > 0: strip polar hydrogens only; < 0: keep all hydrogens;
        #   == 0: strip all hydrogens.
        # Fixed: the elif previously read bare `args.only_polar` while the
        # other branches used `self.args` — same object, but kept consistent.
        if self.args.only_polar > 0:
            self.args.remove_polar_hydrogen = True
        elif self.args.only_polar < 0:
            self.args.remove_polar_hydrogen = False
        else:
            self.args.remove_hydrogen = True

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build the task: load the atom dictionary from the data directory."""
        dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        split_path = os.path.join(self.args.data, split + ".lmdb")

        raw_dataset = LMDBDataset(split_path)

        def one_dataset(raw_dataset, coord_seed, mask_seed):
            """Wire the per-sample pipeline; returns (net_input, target) dicts."""
            # Training additionally augments each molecule with a 2D conformer
            # generated from its SMILES string.
            if self.args.mode == "train":
                raw_dataset = Add2DConformerDataset(
                    raw_dataset, "smi", "atoms", "coordinates"
                )
            smi_dataset = KeyDataset(raw_dataset, "smi")
            # Sample one conformer per molecule, then normalize the atom set.
            dataset = ConformerSampleDataset(
                raw_dataset, coord_seed, "atoms", "coordinates"
            )
            dataset = AtomTypeDataset(raw_dataset, dataset)
            dataset = RemoveHydrogenDataset(
                dataset,
                "atoms",
                "coordinates",
                self.args.remove_hydrogen,
                self.args.remove_polar_hydrogen,
            )
            dataset = CroppingDataset(
                dataset, self.seed, "atoms", "coordinates", self.args.max_atoms
            )
            dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True)
            token_dataset = KeyDataset(dataset, "atoms")
            token_dataset = TokenizeDataset(
                token_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
            )
            coord_dataset = KeyDataset(dataset, "coordinates")
            # Apply BERT-style token masking plus coordinate noise on the
            # masked atoms; yields "atoms"/"targets"/"coordinates" keys.
            expand_dataset = MaskPointsDataset(
                token_dataset,
                coord_dataset,
                self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.mask_idx,
                noise_type=self.args.noise_type,
                noise=self.args.noise,
                seed=mask_seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
            )

            def PrependAndAppend(dataset, pre_token, app_token):
                # Frame every sequence with BOS/EOS-style sentinels.
                dataset = PrependTokenDataset(dataset, pre_token)
                return AppendTokenDataset(dataset, app_token)

            encoder_token_dataset = KeyDataset(expand_dataset, "atoms")
            encoder_target_dataset = KeyDataset(expand_dataset, "targets")
            encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates")

            src_dataset = PrependAndAppend(
                encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos()
            )
            # Targets are framed with pad so sentinel positions are ignored.
            tgt_dataset = PrependAndAppend(
                encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad()
            )
            encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0)
            encoder_distance_dataset = DistanceDataset(encoder_coord_dataset)

            edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
            # Clean (unmasked) coordinates/distances serve as targets.
            coord_dataset = FromNumpyDataset(coord_dataset)
            coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
            distance_dataset = DistanceDataset(coord_dataset)
            return {
                "src_tokens": RightPadDataset(
                    src_dataset,
                    pad_idx=self.dictionary.pad(),
                ),
                "src_coord": RightPadDatasetCoord(
                    encoder_coord_dataset,
                    pad_idx=0,
                ),
                "src_distance": RightPadDataset2D(
                    encoder_distance_dataset,
                    pad_idx=0,
                ),
                "src_edge_type": RightPadDataset2D(
                    edge_type,
                    pad_idx=0,
                ),
            }, {
                "tokens_target": RightPadDataset(
                    tgt_dataset, pad_idx=self.dictionary.pad()
                ),
                "distance_target": RightPadDataset2D(distance_dataset, pad_idx=0),
                "coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0),
                "smi_name": RawArrayDataset(smi_dataset),
            }

        net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed)
        dataset = {"net_input": net_input, "target": target}
        dataset = NestedDictionaryDataset(dataset)
        # Only training splits are reshuffled every epoch.
        if split in ["train", "train.small"]:
            dataset = EpochShuffleDataset(dataset, len(dataset), self.args.seed)
        self.datasets[split] = dataset

    def build_model(self, args):
        """Build and return the model registered for this architecture."""
        from unicore import models

        model = models.build_model(args, self)
        return model