--- a
+++ b/sybil/utils/sampler.py
@@ -0,0 +1,100 @@
+import math
+from typing import TypeVar, Optional, Iterator, Sequence
+
+import torch
+from torch.utils.data import Dataset
+import torch.distributed as dist
+
+
+T_co = TypeVar('T_co', covariant=True)
+
+
+class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler):
+    r"""Extension of PyTorch's native :class:`DistributedSampler` that supports weighted sampling.
+
+    .. note::
+        Dataset is assumed to be of constant size.
+
+    Arguments:
+        dataset: Dataset used for sampling.
+        weights (Sequence[float]): Per-sample weights; samples are drawn with
+            probability proportional to their weight.
+        replacement (bool, optional): If ``True`` (default), samples are drawn
+            with replacement.
+        generator (torch.Generator, optional): Generator used for the weighted
+            draw. If ``None``, a generator seeded with ``seed + epoch`` is used.
+        num_replicas (int, optional): Number of processes participating in
+            distributed training. By default, :attr:`world_size` is retrieved
+            from the current distributed group.
+        rank (int, optional): Rank of the current process within
+            :attr:`num_replicas`. By default, :attr:`rank` is retrieved from
+            the current distributed group.
+        seed (int, optional): random seed used to draw the samples if
+            :attr:`generator` is ``None``. This number should be identical
+            across all processes in the distributed group. Default: ``0``.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. Default: ``False``.
+
+    .. warning::
+        In distributed mode, calling the :meth:`set_epoch` method at
+        the beginning of each epoch **before** creating the :class:`DataLoader`
+        iterator is necessary to make the sampling vary properly across multiple
+        epochs. Otherwise, the same ordering will always be used.
+
+    Example::
+
+        >>> sampler = DistributedWeightedSampler(dataset, weights) if is_distributed else None
+        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
+        ...                     sampler=sampler)
+        >>> for epoch in range(start_epoch, n_epochs):
+        ...     if is_distributed:
+        ...         sampler.set_epoch(epoch)
+        ...     train(loader)
+    """
+
+    def __init__(self, dataset: Dataset, weights: Sequence[float],
+                 replacement: bool = True, generator=None,
+                 num_replicas: Optional[int] = None, rank: Optional[int] = None,
+                 seed: int = 0, drop_last: bool = False) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        self.weights = torch.as_tensor(weights, dtype=torch.double)
+        self.replacement = replacement
+        self.generator = generator
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.num_replicas) / self.num_replicas
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+
+    def __iter__(self) -> Iterator[T_co]:
+        indices = list(range(len(self.dataset)))
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[:(self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[:self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample: each rank keeps a strided slice of the padded index list
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        # gather the weights corresponding to this rank's shard of indices
+        weights = self.weights[indices]
+
+        # seed the draw with seed + epoch so that set_epoch() changes the
+        # sampling each epoch, unless an explicit generator was provided
+        if self.generator is None:
+            generator = torch.Generator()
+            generator.manual_seed(self.seed + self.epoch)
+        else:
+            generator = self.generator
+
+        # draw num_samples dataset indices from this rank's shard, with
+        # probability proportional to their weights
+        rand_tensor = torch.multinomial(weights, self.num_samples, self.replacement, generator=generator)
+        return iter(torch.as_tensor(indices)[rand_tensor].tolist())
+
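A minimal usage sketch (not part of the diff) of how the sampler is expected to be wired into a DataLoader, assuming torch.distributed has already been initialised (e.g. via torchrun); the toy dataset and the 5x positive-class weighting are hypothetical illustrations, not values from this repository:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from sybil.utils.sampler import DistributedWeightedSampler

    # toy dataset and per-sample weights (hypothetical: upweight positives 5x)
    features = torch.randn(100, 16)
    labels = torch.randint(0, 2, (100,))
    dataset = TensorDataset(features, labels)
    weights = [5.0 if y == 1 else 1.0 for y in labels.tolist()]

    # requires an initialised process group so world_size/rank can be queried
    sampler = DistributedWeightedSampler(dataset, weights, seed=0)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=False)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # changes the weighted draw each epoch
        for batch_features, batch_labels in loader:
            pass  # training step goes here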