--- a +++ b/unimol/data/mask_points_dataset.py @@ -0,0 +1,267 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +import numpy as np +import torch +from unicore.data import Dictionary +from unicore.data import BaseWrapperDataset +from . import data_utils + + +class MaskPointsDataset(BaseWrapperDataset): + def __init__( + self, + dataset: torch.utils.data.Dataset, + coord_dataset: torch.utils.data.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + noise_type: str, + noise: float = 1.0, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.coord_dataset = coord_dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.noise_type = noise_type + self.noise = noise + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + + if random_token_prob > 0.0: + weights = np.ones(len(self.vocab)) + weights[vocab.special_index()] = 0 + self.weights = weights / weights.sum() + + self.epoch = None + if self.noise_type == "trunc_normal": + self.noise_f = lambda num_mask: np.clip( + np.random.randn(num_mask, 3) * self.noise, + a_min=-self.noise * 2.0, + a_max=self.noise * 2.0, + ) + elif self.noise_type == "normal": + self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise + elif self.noise_type == "uniform": + self.noise_f = lambda num_mask: np.random.uniform( + low=-self.noise, high=self.noise, size=(num_mask, 3) + ) + else: + self.noise_f = lambda num_mask: 0.0 + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.coord_dataset.set_epoch(epoch) + self.dataset.set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.epoch, index) + + @lru_cache(maxsize=16) + def __getitem_cached__(self, epoch: int, index: int): + ret = {} + with data_utils.numpy_seed(self.seed, epoch, index): + item = self.dataset[index] + coord = self.coord_dataset[index] + sz = len(item) + # don't allow empty sequence + assert sz > 0 + # decide elements to mask + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz + + np.random.rand() + ) + mask_idc = np.random.choice(sz, num_mask, replace=False) + mask = np.full(sz, False) + mask[mask_idc] = True + ret["targets"] = np.full(len(mask), self.pad_idx) + ret["targets"][mask] = item[mask] + ret["targets"] = torch.from_numpy(ret["targets"]).long() + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + + num_mask = mask.astype(np.int32).sum() + new_coord = np.copy(coord) + new_coord[mask, :] += self.noise_f(num_mask) + + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + ret["atoms"] = torch.from_numpy(new_item).long() + ret["coordinates"] = torch.from_numpy(new_coord).float() + return ret + + +class MaskPointsPocketDataset(BaseWrapperDataset): + def __init__( + self, + dataset: torch.utils.data.Dataset, + coord_dataset: torch.utils.data.Dataset, + residue_dataset: torch.utils.data.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + noise_type: str, + noise: float = 1.0, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.coord_dataset = coord_dataset + self.residue_dataset = residue_dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.noise_type = noise_type + self.noise = noise + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + + if random_token_prob > 0.0: + weights = np.ones(len(self.vocab)) + weights[vocab.special_index()] = 0 + self.weights = weights / weights.sum() + + self.epoch = None + if self.noise_type == "trunc_normal": + self.noise_f = lambda num_mask: np.clip( + np.random.randn(num_mask, 3) * self.noise, + a_min=-self.noise * 2.0, + a_max=self.noise * 2.0, + ) + elif self.noise_type == "normal": + self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise + elif self.noise_type == "uniform": + self.noise_f = lambda num_mask: np.random.uniform( + low=-self.noise, high=self.noise, size=(num_mask, 3) + ) + else: + self.noise_f = lambda num_mask: 0.0 + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.coord_dataset.set_epoch(epoch) + self.dataset.set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.epoch, index) + + @lru_cache(maxsize=16) + def __getitem_cached__(self, epoch: int, index: int): + ret = {} + with data_utils.numpy_seed(self.seed, epoch, index): + item = self.dataset[index] + coord = self.coord_dataset[index] + sz = len(item) + # don't allow empty sequence + assert sz > 0 + + # mask on the level of residues + residue = self.residue_dataset[index] + res_list = list(set(residue)) + res_sz = len(res_list) + + # decide elements to mask + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * res_sz + + np.random.rand() + ) + mask_res = np.random.choice(res_list, num_mask, replace=False).tolist() + mask = np.isin(residue, mask_res) + + ret["targets"] = np.full(len(mask), self.pad_idx) + ret["targets"][mask] = item[mask] + ret["targets"] = torch.from_numpy(ret["targets"]).long() + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + + num_mask = mask.astype(np.int32).sum() + new_coord = np.copy(coord) + new_coord[mask, :] += self.noise_f(num_mask) + + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + ret["atoms"] = torch.from_numpy(new_item).long() + ret["coordinates"] = torch.from_numpy(new_coord).float() + return ret