Diff of /unimol/tasks/unimol.py [000000] .. [b40915]

Switch to side-by-side view

--- a
+++ b/unimol/tasks/unimol.py
@@ -0,0 +1,237 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import numpy as np
+from unicore.data import (
+    Dictionary,
+    NestedDictionaryDataset,
+    AppendTokenDataset,
+    PrependTokenDataset,
+    RightPadDataset,
+    EpochShuffleDataset,
+    TokenizeDataset,
+    RightPadDataset2D,
+    FromNumpyDataset,
+    RawArrayDataset,
+)
+from unimol.data import (
+    KeyDataset,
+    ConformerSampleDataset,
+    DistanceDataset,
+    EdgeTypeDataset,
+    MaskPointsDataset,
+    RemoveHydrogenDataset,
+    AtomTypeDataset,
+    NormalizeDataset,
+    CroppingDataset,
+    RightPadDatasetCoord,
+    Add2DConformerDataset,
+    LMDBDataset,
+)
+from unicore.tasks import UnicoreTask, register_task
+
+
+logger = logging.getLogger(__name__)
+
+
+@register_task("unimol")
+class UniMolTask(UnicoreTask):
+    """Task for training transformer auto-encoder models."""
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument(
+            "data",
+            help="colon separated path to data directories list, \
+                            will be iterated upon during epochs in round-robin manner",
+        )
+        parser.add_argument(
+            "--mask-prob",
+            default=0.15,
+            type=float,
+            help="probability of replacing a token with mask",
+        )
+        parser.add_argument(
+            "--leave-unmasked-prob",
+            default=0.05,
+            type=float,
+            help="probability that a masked token is unmasked",
+        )
+        parser.add_argument(
+            "--random-token-prob",
+            default=0.05,
+            type=float,
+            help="probability of replacing a token with a random token",
+        )
+        parser.add_argument(
+            "--noise-type",
+            default="uniform",
+            choices=["trunc_normal", "uniform", "normal", "none"],
+            help="noise type in coordinate noise",
+        )
+        parser.add_argument(
+            "--noise",
+            default=1.0,
+            type=float,
+            help="coordinate noise for masked atoms",
+        )
+        parser.add_argument(
+            "--remove-hydrogen",
+            action="store_true",
+            help="remove hydrogen atoms",
+        )
+        parser.add_argument(
+            "--remove-polar-hydrogen",
+            action="store_true",
+            help="remove polar hydrogen atoms",
+        )
+        parser.add_argument(
+            "--max-atoms",
+            type=int,
+            default=256,
+            help="selected maximum number of atoms in a molecule",
+        )
+        parser.add_argument(
+            "--dict-name",
+            default="dict.txt",
+            help="dictionary file",
+        )
+        parser.add_argument(
+            "--only-polar",
+            default=1,
+            type=int,
+            help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ",
+        )
+
+    def __init__(self, args, dictionary):
+        super().__init__(args)
+        self.dictionary = dictionary
+        self.seed = args.seed
+        # add mask token
+        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
+        if self.args.only_polar > 0:
+            self.args.remove_polar_hydrogen = True
+        elif args.only_polar < 0:
+            self.args.remove_polar_hydrogen = False
+        else:
+            self.args.remove_hydrogen = True
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
+        logger.info("dictionary: {} types".format(len(dictionary)))
+        return cls(args, dictionary)
+
+    def load_dataset(self, split, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        split_path = os.path.join(self.args.data, split + ".lmdb")
+
+        raw_dataset = LMDBDataset(split_path)
+
+        def one_dataset(raw_dataset, coord_seed, mask_seed):
+            if self.args.mode =='train':
+                raw_dataset = Add2DConformerDataset(
+                    raw_dataset, "smi", "atoms", "coordinates"
+                )
+            smi_dataset = KeyDataset(raw_dataset, "smi")
+            dataset = ConformerSampleDataset(
+                raw_dataset, coord_seed, "atoms", "coordinates"
+            )
+            dataset = AtomTypeDataset(raw_dataset, dataset)
+            dataset = RemoveHydrogenDataset(
+                dataset,
+                "atoms",
+                "coordinates",
+                self.args.remove_hydrogen,
+                self.args.remove_polar_hydrogen,
+            )
+            dataset = CroppingDataset(
+                dataset, self.seed, "atoms", "coordinates", self.args.max_atoms
+            )
+            dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True)
+            token_dataset = KeyDataset(dataset, "atoms")
+            token_dataset = TokenizeDataset(
+                token_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
+            )
+            coord_dataset = KeyDataset(dataset, "coordinates")
+            expand_dataset = MaskPointsDataset(
+                token_dataset,
+                coord_dataset,
+                self.dictionary,
+                pad_idx=self.dictionary.pad(),
+                mask_idx=self.mask_idx,
+                noise_type=self.args.noise_type,
+                noise=self.args.noise,
+                seed=mask_seed,
+                mask_prob=self.args.mask_prob,
+                leave_unmasked_prob=self.args.leave_unmasked_prob,
+                random_token_prob=self.args.random_token_prob,
+            )
+
+            def PrependAndAppend(dataset, pre_token, app_token):
+                dataset = PrependTokenDataset(dataset, pre_token)
+                return AppendTokenDataset(dataset, app_token)
+
+            encoder_token_dataset = KeyDataset(expand_dataset, "atoms")
+            encoder_target_dataset = KeyDataset(expand_dataset, "targets")
+            encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates")
+
+            src_dataset = PrependAndAppend(
+                encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos()
+            )
+            tgt_dataset = PrependAndAppend(
+                encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad()
+            )
+            encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0)
+            encoder_distance_dataset = DistanceDataset(encoder_coord_dataset)
+
+            edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
+            coord_dataset = FromNumpyDataset(coord_dataset)
+            coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
+            distance_dataset = DistanceDataset(coord_dataset)
+            return {
+                "src_tokens": RightPadDataset(
+                    src_dataset,
+                    pad_idx=self.dictionary.pad(),
+                ),
+                "src_coord": RightPadDatasetCoord(
+                    encoder_coord_dataset,
+                    pad_idx=0,
+                ),
+                "src_distance": RightPadDataset2D(
+                    encoder_distance_dataset,
+                    pad_idx=0,
+                ),
+                "src_edge_type": RightPadDataset2D(
+                    edge_type,
+                    pad_idx=0,
+                ),
+            }, {
+                "tokens_target": RightPadDataset(
+                    tgt_dataset, pad_idx=self.dictionary.pad()
+                ),
+                "distance_target": RightPadDataset2D(distance_dataset, pad_idx=0),
+                "coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0),
+                "smi_name": RawArrayDataset(smi_dataset),
+            }
+
+        net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed)
+        dataset = {"net_input": net_input, "target": target}
+        dataset = NestedDictionaryDataset(dataset)
+        if split in ["train", "train.small"]:
+            dataset = EpochShuffleDataset(dataset, len(dataset), self.args.seed)
+        self.datasets[split] = dataset
+
+    def build_model(self, args):
+        from unicore import models
+
+        model = models.build_model(args, self)
+        return model