# unimol/tasks/unimol.py
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import numpy as np
from unicore.data import (
    Dictionary,
    NestedDictionaryDataset,
    AppendTokenDataset,
    PrependTokenDataset,
    RightPadDataset,
    EpochShuffleDataset,
    TokenizeDataset,
    RightPadDataset2D,
    FromNumpyDataset,
    RawArrayDataset,
)
from unimol.data import (
    KeyDataset,
    ConformerSampleDataset,
    DistanceDataset,
    EdgeTypeDataset,
    MaskPointsDataset,
    RemoveHydrogenDataset,
    AtomTypeDataset,
    NormalizeDataset,
    CroppingDataset,
    RightPadDatasetCoord,
    Add2DConformerDataset,
    LMDBDataset,
)
from unicore.tasks import UnicoreTask, register_task


logger = logging.getLogger(__name__)


@register_task("unimol")
class UniMolTask(UnicoreTask):
    """Task for training transformer auto-encoder models."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon-separated paths to data directories; "
            "iterated over in round-robin order across epochs",
        )
        parser.add_argument(
            "--mask-prob",
            default=0.15,
            type=float,
            help="probability of selecting a token for masking",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            default=0.05,
            type=float,
            help="probability that a selected token is left unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            default=0.05,
            type=float,
            help="probability of replacing a selected token with a random token",
        )
        parser.add_argument(
            "--noise-type",
            default="uniform",
            choices=["trunc_normal", "uniform", "normal", "none"],
            help="distribution of the coordinate noise",
        )
        parser.add_argument(
            "--noise",
            default=1.0,
            type=float,
            help="magnitude of the coordinate noise applied to masked atoms",
        )
        parser.add_argument(
            "--remove-hydrogen",
            action="store_true",
            help="remove hydrogen atoms",
        )
        parser.add_argument(
            "--remove-polar-hydrogen",
            action="store_true",
            help="remove polar hydrogen atoms",
        )
        parser.add_argument(
            "--max-atoms",
            type=int,
            default=256,
            help="maximum number of atoms kept per molecule (larger molecules are cropped)",
        )
        parser.add_argument(
            "--dict-name",
            default="dict.txt",
            help="dictionary file",
        )
        parser.add_argument(
            "--only-polar",
            default=1,
            type=int,
            help="1: keep only polar hydrogen; -1: keep all hydrogen; 0: remove all hydrogen",
        )

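    # An illustrative invocation sketch (an assumption, not the project's
    # documented command line: it presumes unicore's standard unicore-train
    # entry point, and omits the model/optimization flags a real run needs):
    #
    #   unicore-train $DATA_DIR --task unimol \
    #       --mask-prob 0.15 --noise-type uniform --noise 1.0 --only-polar 1
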
    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # add mask token
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
        # map --only-polar onto the hydrogen-removal flags
        # (1: keep only polar hydrogen; -1: keep all hydrogen; 0: remove all hydrogen)
        if self.args.only_polar > 0:
            self.args.remove_polar_hydrogen = True
        elif self.args.only_polar < 0:
            self.args.remove_polar_hydrogen = False
        else:
            self.args.remove_hydrogen = True

    @classmethod
    def setup_task(cls, args, **kwargs):
        dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        split_path = os.path.join(self.args.data, split + ".lmdb")

        raw_dataset = LMDBDataset(split_path)

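        # one_dataset wires up the full preprocessing pipeline for one split:
        # sample a conformer, optionally strip hydrogens, crop to --max-atoms,
        # center the coordinates, tokenize the atom types, and apply
        # BERT-style masking with coordinate noise.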
        def one_dataset(raw_dataset, coord_seed, mask_seed):
140
            if self.args.mode =='train':
141
                raw_dataset = Add2DConformerDataset(
142
                    raw_dataset, "smi", "atoms", "coordinates"
143
                )
144
            smi_dataset = KeyDataset(raw_dataset, "smi")
145
            dataset = ConformerSampleDataset(
146
                raw_dataset, coord_seed, "atoms", "coordinates"
147
            )
148
            dataset = AtomTypeDataset(raw_dataset, dataset)
149
            dataset = RemoveHydrogenDataset(
150
                dataset,
151
                "atoms",
152
                "coordinates",
153
                self.args.remove_hydrogen,
154
                self.args.remove_polar_hydrogen,
155
            )
156
            dataset = CroppingDataset(
157
                dataset, self.seed, "atoms", "coordinates", self.args.max_atoms
158
            )
159
            dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True)
160
            token_dataset = KeyDataset(dataset, "atoms")
161
            token_dataset = TokenizeDataset(
162
                token_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
163
            )
164
            coord_dataset = KeyDataset(dataset, "coordinates")
165
            expand_dataset = MaskPointsDataset(
166
                token_dataset,
167
                coord_dataset,
168
                self.dictionary,
169
                pad_idx=self.dictionary.pad(),
170
                mask_idx=self.mask_idx,
171
                noise_type=self.args.noise_type,
172
                noise=self.args.noise,
173
                seed=mask_seed,
174
                mask_prob=self.args.mask_prob,
175
                leave_unmasked_prob=self.args.leave_unmasked_prob,
176
                random_token_prob=self.args.random_token_prob,
177
            )
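            # Per its parameters, MaskPointsDataset applies BERT-style
            # corruption: a --mask-prob fraction of atoms is selected, and a
            # selected token is replaced by [MASK] except for the
            # --leave-unmasked-prob / --random-token-prob shares; the
            # coordinates of masked atoms are perturbed with --noise-type
            # noise of magnitude --noise.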

            def PrependAndAppend(dataset, pre_token, app_token):
                dataset = PrependTokenDataset(dataset, pre_token)
                return AppendTokenDataset(dataset, app_token)

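            # Frame every per-atom sequence consistently: tokens get bos/eos,
            # targets get pad at both ends, and coordinates get 0.0 rows, so
            # all tensors stay aligned after the special tokens are added.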
            encoder_token_dataset = KeyDataset(expand_dataset, "atoms")
            encoder_target_dataset = KeyDataset(expand_dataset, "targets")
            encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates")

            src_dataset = PrependAndAppend(
                encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos()
            )
            tgt_dataset = PrependAndAppend(
                encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad()
            )
            encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0)
            encoder_distance_dataset = DistanceDataset(encoder_coord_dataset)

            edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
            coord_dataset = FromNumpyDataset(coord_dataset)
            coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
            distance_dataset = DistanceDataset(coord_dataset)
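            # Inputs are built from the masked/noised copies (expand_dataset);
            # targets keep the clean tokens and the original coordinates.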
            return {
                "src_tokens": RightPadDataset(
                    src_dataset,
                    pad_idx=self.dictionary.pad(),
                ),
                "src_coord": RightPadDatasetCoord(
                    encoder_coord_dataset,
                    pad_idx=0,
                ),
                "src_distance": RightPadDataset2D(
                    encoder_distance_dataset,
                    pad_idx=0,
                ),
                "src_edge_type": RightPadDataset2D(
                    edge_type,
                    pad_idx=0,
                ),
            }, {
                "tokens_target": RightPadDataset(
                    tgt_dataset, pad_idx=self.dictionary.pad()
                ),
                "distance_target": RightPadDataset2D(distance_dataset, pad_idx=0),
                "coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0),
                "smi_name": RawArrayDataset(smi_dataset),
            }

        net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed)
        dataset = {"net_input": net_input, "target": target}
        dataset = NestedDictionaryDataset(dataset)
        if split in ["train", "train.small"]:
            # reshuffle sample order at every epoch on training splits
            dataset = EpochShuffleDataset(dataset, len(dataset), self.args.seed)
        self.datasets[split] = dataset

    def build_model(self, args):
        from unicore import models

        model = models.build_model(args, self)
        return model
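

# A minimal sketch of driving this task programmatically (the ``args``
# namespace is an assumption: it must carry the options registered in
# add_args plus unicore's common fields such as seed, mode, and max_seq_len;
# real training normally goes through unicore's CLI instead):
#
#   task = UniMolTask.setup_task(args)
#   task.load_dataset("train")
#   model = task.build_model(args)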