--- a
+++ b/unimol/data/resampling_dataset.py
@@ -0,0 +1,117 @@
+import logging
+
+import numpy as np
+
+from unicore.data  import BaseWrapperDataset
+from .plasma_utils import PlasmaArray
+
+logger = logging.getLogger(__name__)
+
+
+
+class ResamplingDataset(BaseWrapperDataset):
+    """Randomly resample the wrapped dataset at every epoch.
+
+    A fresh set of indices into ``dataset`` is drawn whenever the epoch
+    changes. Sampling is uniform, with or without replacement depending on
+    ``replace``, and is a deterministic function of ``seed`` and the epoch.
+
+    The epoch size can be rescaled with ``size_ratio``. When sampling without
+    replacement the resampled size cannot exceed the size of ``dataset``, so
+    ``size_ratio`` must not be greater than 1 in that mode.
+
+    Args:
+        dataset: dataset on which to sample; must support ``len()``.
+        replace (bool): sampling mode; True for "with replacement", or False
+            for "without replacement" (default: False).
+        size_ratio (float): the ratio to subsample to; must be positive
+            (default: 1.0).
+        batch_by_size (bool): accepted for interface compatibility; it is
+            not used by this implementation (default: True).
+        seed (int): RNG seed to use (default: 0).
+        epoch (int): starting epoch number (default: 1).
+
+    Raises:
+        ValueError: if ``size_ratio`` is not positive, or if sampling
+            without replacement would need more items than ``dataset`` has.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        replace=False,
+        size_ratio=1.0,
+        batch_by_size=True,
+        seed=0,
+        epoch=1,
+    ):
+        super().__init__(dataset)
+
+        self.replace = replace
+        self.size_ratio = float(size_ratio)
+        if self.size_ratio <= 0:
+            raise ValueError(
+                "size_ratio must be positive, got {}".format(size_ratio)
+            )
+
+        # Plain int so that __len__ returns a builtin integer.
+        self.actual_size = int(np.ceil(len(dataset) * self.size_ratio))
+        if not self.replace and self.actual_size > len(dataset):
+            raise ValueError(
+                "cannot draw {} items without replacement from a dataset "
+                "of size {}".format(self.actual_size, len(dataset))
+            )
+
+        self.seed = seed
+        self._cur_epoch = None
+        self._cur_indices = None
+
+        self.set_epoch(epoch)
+
+    def __getitem__(self, index):
+        return self.dataset[self._cur_indices.array[index]]
+
+    def __len__(self):
+        return self.actual_size
+
+    @property
+    def sizes(self):
+        # The wrapped dataset may expose one size array or a list of them.
+        if isinstance(self.dataset.sizes, list):
+            return [s[self._cur_indices.array] for s in self.dataset.sizes]
+        return self.dataset.sizes[self._cur_indices.array]
+
+    def num_tokens(self, index):
+        return self.dataset.num_tokens(self._cur_indices.array[index])
+
+    def size(self, index):
+        return self.dataset.size(self._cur_indices.array[index])
+
+    def ordered_indices(self):
+        return np.arange(len(self))
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(self._cur_indices.array[indices])
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        # Sampled indices are redrawn in set_epoch, so report non-reusable.
+        return False
+
+    def set_epoch(self, epoch):
+        """Redraw the sampled indices if ``epoch`` differs from the current one."""
+        logger.info("ResamplingDataset.set_epoch: %s", epoch)
+        super().set_epoch(epoch)
+
+        if epoch == self._cur_epoch:
+            return
+        self._cur_epoch = epoch
+
+        # Seed from (constant, global seed, epoch) so every epoch gets its
+        # own reproducible uniform sample of indices.
+        rng = np.random.RandomState(
+            [42, self.seed % (2**32), self._cur_epoch]
+        )
+        self._cur_indices = PlasmaArray(
+            rng.choice(len(self.dataset), self.actual_size, replace=self.replace)
+        )
\ No newline at end of file