
--- /dev/null
+++ b/mmaction/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from collections import defaultdict
+
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+    """DistributedSampler inheriting from
+    ``torch.utils.data.DistributedSampler``.
+
+    In older versions of PyTorch, ``DistributedSampler`` has no ``shuffle``
+    argument; this subclass ports one.
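+
+    A minimal usage sketch (``dataset`` stands in for any map-style
+    dataset; ``set_epoch`` should be called every epoch so the shuffle
+    order changes across epochs)::
+
+        sampler = DistributedSampler(dataset, num_replicas=8, rank=0,
+                                     shuffle=True, seed=42)
+        loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
+        for epoch in range(num_epochs):
+            sampler.set_epoch(epoch)
+            for batch in loader:
+                ...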
+    """
+
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 seed=0):
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        # store the seed manually for compatibility with PyTorch versions
+        # whose DistributedSampler does not accept a ``seed`` argument
+        self.seed = seed if seed is not None else 0
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch + self.seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample: rank r keeps every num_replicas-th index starting at r
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+
+class ClassSpecificDistributedSampler(_DistributedSampler):
+    """ClassSpecificDistributedSampler inheriting from
+    ``torch.utils.data.DistributedSampler``.
+
+    Samples are drawn with a class-specific probability, which should be an
+    attribute of the dataset (``dataset.class_prob``, a dictionary mapping
+    each label index to its sampling probability). This sampler is only
+    applicable to single-label recognition datasets and is also compatible
+    with RepeatDataset.
+
+    ``dynamic_length`` defaults to True, meaning samples are oversampled or
+    subsampled per class, so the epoch length may change. If
+    ``dynamic_length`` is set to False, the epoch length stays fixed.
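+
+    A minimal usage sketch (``MyVideoDataset`` is hypothetical; any dataset
+    exposing ``video_infos`` and a ``class_prob`` attribute fits)::
+
+        dataset = MyVideoDataset(ann_file, pipeline)
+        dataset.class_prob = {0: 1.0, 1: 2.5}  # oversample class 1 by 2.5x
+        sampler = ClassSpecificDistributedSampler(
+            dataset, num_replicas=4, rank=0, dynamic_length=True)
+        loader = torch.utils.data.DataLoader(dataset, sampler=sampler)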
+    """
+
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 dynamic_length=True,
+                 shuffle=True,
+                 seed=0):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+
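+        # unwrap RepeatDataset so ``class_prob`` is read from the inner
+        # dataset (matched by class name rather than by import)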
+        if type(dataset).__name__ == 'RepeatDataset':
+            dataset = dataset.dataset
+
+        assert hasattr(dataset, 'class_prob')
+
+        self.class_prob = dataset.class_prob
+        self.dynamic_length = dynamic_length
+        # store the seed manually for compatibility with PyTorch versions
+        # whose DistributedSampler does not accept a ``seed`` argument
+        self.seed = seed if seed is not None else 0
+
+    def __iter__(self):
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        class_indices = defaultdict(list)
+
+        # To be compatible with RepeatDataset
+        times = 1
+        dataset = self.dataset
+        if type(dataset).__name__ == 'RepeatDataset':
+            times = dataset.times
+            dataset = dataset.dataset
+        for i, item in enumerate(dataset.video_infos):
+            class_indices[item['label']].append(i)
+
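+        # Two resampling strategies follow. With dynamic_length, each class
+        # is duplicated int(prob) times and the fractional remainder is
+        # drawn at random, so the epoch length follows class_prob. Otherwise
+        # total_size samples are drawn from a multinomial whose per-video
+        # weight is class_prob[label] / class_size, keeping the length fixed.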
+        if self.dynamic_length:
+            indices = []
+            for k, prob in self.class_prob.items():
+                prob = prob * times
+                # e.g. prob = 2.3: repeat every index of class k twice,
+                # then add a random 30% of them
+                for _ in range(int(prob // 1)):
+                    indices.extend(class_indices[k])
+                rem = int((prob % 1) * len(class_indices[k]))
+                perm = torch.randperm(
+                    len(class_indices[k]), generator=g).tolist()[:rem]
+                # map the sampled positions back to dataset indices
+                indices.extend(class_indices[k][i] for i in perm)
+            if self.shuffle:
+                shuffle = torch.randperm(len(indices), generator=g).tolist()
+                indices = [indices[i] for i in shuffle]
+
+            # re-calc num_samples & total_size
+            self.num_samples = math.ceil(len(indices) / self.num_replicas)
+            self.total_size = self.num_samples * self.num_replicas
+        else:
+            # keep the dataloader length the same as the original dataset:
+            # each video is weighted by its class probability spread evenly
+            # over the videos of that class
+            video_labels = [x['label'] for x in dataset.video_infos]
+            probs = [
+                self.class_prob[lb] / len(class_indices[lb])
+                for lb in video_labels
+            ]
+
+            indices = torch.multinomial(
+                torch.tensor(probs),
+                self.total_size,
+                replacement=True,
+                generator=g).tolist()
+
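+        # pad with leading indices so the total splits evenly across replicas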
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # retrieve indices for current process
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+        return iter(indices)