# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import numpy as np
import torch
from unicore.data import Dictionary
from unicore.data import BaseWrapperDataset
from . import data_utils
class MaskPointsDataset(BaseWrapperDataset):
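    """BERT-style masking over atom tokens with coordinate noise.

    For each sample, about ``mask_prob`` of the atoms are selected as prediction
    targets. Of the selected atoms, a ``leave_unmasked_prob`` fraction keeps its
    original token and coordinates, a ``random_token_prob`` fraction is replaced
    by a random non-special vocabulary token, and the remainder is replaced by
    ``mask_idx``. Replaced atoms also have their coordinates perturbed according
    to ``noise_type`` and ``noise``. ``targets`` holds the original tokens at the
    selected positions and ``pad_idx`` everywhere else.
    """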
def __init__(
self,
dataset: torch.utils.data.Dataset,
coord_dataset: torch.utils.data.Dataset,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
noise_type: str,
noise: float = 1.0,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
self.dataset = dataset
self.coord_dataset = coord_dataset
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.noise_type = noise_type
self.noise = noise
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
if random_token_prob > 0.0:
weights = np.ones(len(self.vocab))
weights[vocab.special_index()] = 0
self.weights = weights / weights.sum()
self.epoch = None
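        # Coordinate noise applied to masked atoms: clipped ("truncated") normal,
        # plain normal, uniform in [-noise, noise], or no noise for any other value.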
if self.noise_type == "trunc_normal":
self.noise_f = lambda num_mask: np.clip(
np.random.randn(num_mask, 3) * self.noise,
a_min=-self.noise * 2.0,
a_max=self.noise * 2.0,
)
elif self.noise_type == "normal":
self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
elif self.noise_type == "uniform":
self.noise_f = lambda num_mask: np.random.uniform(
low=-self.noise, high=self.noise, size=(num_mask, 3)
)
else:
self.noise_f = lambda num_mask: 0.0
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
self.coord_dataset.set_epoch(epoch)
self.dataset.set_epoch(epoch)
self.epoch = epoch
def __getitem__(self, index: int):
return self.__getitem_cached__(self.epoch, index)
@lru_cache(maxsize=16)
def __getitem_cached__(self, epoch: int, index: int):
ret = {}
with data_utils.numpy_seed(self.seed, epoch, index):
item = self.dataset[index]
coord = self.coord_dataset[index]
sz = len(item)
# don't allow empty sequence
assert sz > 0
# decide elements to mask
num_mask = int(
# add a random number for probabilistic rounding
self.mask_prob * sz
+ np.random.rand()
)
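            # e.g. mask_prob * sz = 1.5 yields 1 or 2 masked atoms with equal
            # probability, so the expected count still matches mask_prob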
mask_idc = np.random.choice(sz, num_mask, replace=False)
mask = np.full(sz, False)
mask[mask_idc] = True
ret["targets"] = np.full(len(mask), self.pad_idx)
ret["targets"][mask] = item[mask]
ret["targets"] = torch.from_numpy(ret["targets"]).long()
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
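            # "leave unmasked" positions are removed from `mask` below, so they keep
            # their token and coordinates but remain prediction targets; random-
            # replacement positions stay in `mask` and also receive coordinate noise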
if unmask is not None:
mask = mask ^ unmask
new_item = np.copy(item)
new_item[mask] = self.mask_idx
num_mask = mask.astype(np.int32).sum()
new_coord = np.copy(coord)
new_coord[mask, :] += self.noise_f(num_mask)
if rand_mask is not None:
num_rand = rand_mask.sum()
if num_rand > 0:
new_item[rand_mask] = np.random.choice(
len(self.vocab),
num_rand,
p=self.weights,
)
ret["atoms"] = torch.from_numpy(new_item).long()
ret["coordinates"] = torch.from_numpy(new_coord).float()
return ret
class MaskPointsPocketDataset(BaseWrapperDataset):
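    """Variant of MaskPointsDataset that selects masked positions at the residue level.

    About ``mask_prob`` of the residues in ``residue_dataset`` are sampled, and every
    atom of a sampled residue becomes a prediction target. The per-atom
    leave-unmasked / random-replacement split and the coordinate-noise scheme are the
    same as in MaskPointsDataset.
    """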
def __init__(
self,
dataset: torch.utils.data.Dataset,
coord_dataset: torch.utils.data.Dataset,
residue_dataset: torch.utils.data.Dataset,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
noise_type: str,
noise: float = 1.0,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
self.dataset = dataset
self.coord_dataset = coord_dataset
self.residue_dataset = residue_dataset
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.noise_type = noise_type
self.noise = noise
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
if random_token_prob > 0.0:
weights = np.ones(len(self.vocab))
weights[vocab.special_index()] = 0
self.weights = weights / weights.sum()
self.epoch = None
if self.noise_type == "trunc_normal":
self.noise_f = lambda num_mask: np.clip(
np.random.randn(num_mask, 3) * self.noise,
a_min=-self.noise * 2.0,
a_max=self.noise * 2.0,
)
elif self.noise_type == "normal":
self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
elif self.noise_type == "uniform":
self.noise_f = lambda num_mask: np.random.uniform(
low=-self.noise, high=self.noise, size=(num_mask, 3)
)
else:
self.noise_f = lambda num_mask: 0.0
    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.coord_dataset.set_epoch(epoch)
        # also keep the residue view in sync with the current epoch, if it is epoch-aware
        if hasattr(self.residue_dataset, "set_epoch"):
            self.residue_dataset.set_epoch(epoch)
        self.dataset.set_epoch(epoch)
        self.epoch = epoch
def __getitem__(self, index: int):
return self.__getitem_cached__(self.epoch, index)
@lru_cache(maxsize=16)
def __getitem_cached__(self, epoch: int, index: int):
ret = {}
with data_utils.numpy_seed(self.seed, epoch, index):
item = self.dataset[index]
coord = self.coord_dataset[index]
sz = len(item)
# don't allow empty sequence
assert sz > 0
# mask on the level of residues
residue = self.residue_dataset[index]
res_list = list(set(residue))
res_sz = len(res_list)
# decide elements to mask
num_mask = int(
# add a random number for probabilistic rounding
self.mask_prob * res_sz
+ np.random.rand()
)
mask_res = np.random.choice(res_list, num_mask, replace=False).tolist()
mask = np.isin(residue, mask_res)
ret["targets"] = np.full(len(mask), self.pad_idx)
ret["targets"][mask] = item[mask]
ret["targets"] = torch.from_numpy(ret["targets"]).long()
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
if unmask is not None:
mask = mask ^ unmask
new_item = np.copy(item)
new_item[mask] = self.mask_idx
num_mask = mask.astype(np.int32).sum()
new_coord = np.copy(coord)
new_coord[mask, :] += self.noise_f(num_mask)
if rand_mask is not None:
num_rand = rand_mask.sum()
if num_rand > 0:
new_item[rand_mask] = np.random.choice(
len(self.vocab),
num_rand,
p=self.weights,
)
ret["atoms"] = torch.from_numpy(new_item).long()
ret["coordinates"] = torch.from_numpy(new_coord).float()
return ret
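

# The block below is an illustrative smoke test, not part of the library API: it
# wraps tiny in-memory arrays in the minimal dataset interface these wrappers rely
# on (__getitem__, __len__, set_epoch). With random_token_prob=0.0 the vocabulary
# is never consulted, so a placeholder vocab is passed purely for this sketch.
if __name__ == "__main__":

    class _ListDataset(torch.utils.data.Dataset):
        """Trivial list-backed dataset used only for this example."""

        def __init__(self, items):
            self.items = items

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)

        def set_epoch(self, epoch):
            pass

    atoms = _ListDataset([np.array([4, 5, 6, 5, 4], dtype=np.int64)])
    coords = _ListDataset([np.random.rand(5, 3).astype(np.float32)])
    ds = MaskPointsDataset(
        atoms,
        coords,
        vocab=None,  # unused here because random_token_prob=0.0
        pad_idx=0,
        mask_idx=3,
        noise_type="uniform",
        noise=1.0,
        seed=1,
        mask_prob=0.5,
        leave_unmasked_prob=0.0,
        random_token_prob=0.0,
    )
    ds.set_epoch(1)
    sample = ds[0]
    print(sample["atoms"], sample["targets"], sample["coordinates"].shape)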