--- /dev/null
+++ b/unimol/data/resampling_dataset.py
@@ -0,0 +1,108 @@
+import logging
+
+import numpy as np
+
+from unicore.data import BaseWrapperDataset
+from .plasma_utils import PlasmaArray
+
+logger = logging.getLogger(__name__)
+
+
+class ResamplingDataset(BaseWrapperDataset):
+    """Randomly samples from a given dataset at each epoch.
+    Sampling is done with or without replacement, depending on the "replace"
+    parameter.
+    Optionally, the epoch size can be rescaled. This is potentially desirable
+    to increase per-epoch coverage of the base dataset (since sampling with
+    replacement means that many items in the dataset will be left out). In the
+    case of sampling without replacement, size_ratio should be strictly less
+    than 1.
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset on which to sample.
+        replace (bool): sampling mode; True for "with replacement", or False
+            for "without replacement" (default: False).
+        size_ratio (float): the ratio to subsample to; must be positive
+            (default: 1.0).
+        batch_by_size (bool): whether or not to batch by sequence length
+            (default: True).
+        seed (int): RNG seed to use (default: 0).
+        epoch (int): starting epoch number (default: 1).
+    """
+
+    def __init__(
+        self,
+        dataset,
+        replace=False,
+        size_ratio=1.0,
+        batch_by_size=True,
+        seed=0,
+        epoch=1,
+    ):
+        super().__init__(dataset)
+
+        self.replace = replace
+
+        self.size_ratio = float(size_ratio)
+        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)
+
+        self.seed = seed
+
+        self._cur_epoch = None
+        self._cur_indices = None
+
+        self.set_epoch(epoch)
+
+    def __getitem__(self, index):
+        return self.dataset[self._cur_indices.array[index]]
+
+    def __len__(self):
+        return self.actual_size
+
+    @property
+    def sizes(self):
+        if isinstance(self.dataset.sizes, list):
+            return [s[self._cur_indices.array] for s in self.dataset.sizes]
+        return self.dataset.sizes[self._cur_indices.array]
+
+    def num_tokens(self, index):
+        return self.dataset.num_tokens(self._cur_indices.array[index])
+
+    def size(self, index):
+        return self.dataset.size(self._cur_indices.array[index])
+
+    def ordered_indices(self):
+        return np.arange(len(self))
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(self._cur_indices.array[indices])
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return False
+
+    def set_epoch(self, epoch):
+        logger.info("ResamplingDataset.set_epoch: {}".format(epoch))
+        super().set_epoch(epoch)
+
+        if epoch == self._cur_epoch:
+            return
+
+        self._cur_epoch = epoch
+
+        # Generate a sample of indices as a function of the random seed
+        # and the current epoch.
+        rng = np.random.RandomState(
+            [
+                42,  # magic number
+                self.seed % (2**32),  # global seed
+                self._cur_epoch,  # epoch index
+            ]
+        )
+        self._cur_indices = PlasmaArray(
+            rng.choice(
+                len(self.dataset),
+                self.actual_size,
+                replace=self.replace,
+                p=None,
+            )
+        )
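
For reference, a minimal usage sketch (not part of the patch). It assumes unicore is importable, that PlasmaArray from plasma_utils falls back to a plain in-memory numpy array when no plasma store is configured, and it uses a hypothetical ToyDataset stand-in rather than a real Uni-Mol dataset.

import numpy as np
from unimol.data.resampling_dataset import ResamplingDataset

class ToyDataset:
    """Hypothetical bare-bones dataset exposing only what ResamplingDataset touches."""
    def __init__(self, n):
        self.sizes = np.ones(n, dtype=np.int64)
    def __len__(self):
        return len(self.sizes)
    def __getitem__(self, index):
        return index
    def set_epoch(self, epoch):
        pass

base = ToyDataset(1000)
# Draw a fresh 30% subsample (without replacement) for each epoch.
ds = ResamplingDataset(base, replace=False, size_ratio=0.3, seed=7, epoch=1)
print(len(ds))                      # 300
epoch1 = [ds[i] for i in range(5)]  # indices drawn for epoch 1
ds.set_epoch(2)                     # re-seeds the RNG and re-draws the subsample
epoch2 = [ds[i] for i in range(5)]  # generally differs from epoch1

Because the RNG is seeded from (42, seed, epoch), the subsample is deterministic for a given seed and epoch but changes whenever set_epoch is called with a new epoch number.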