--- /dev/null
+++ b/tests/test_data/test_sampler.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.utils.data import DataLoader, Dataset
+
+from mmaction.datasets.samplers import (ClassSpecificDistributedSampler,
+                                        DistributedSampler)
+
+
+class MyDataset(Dataset):
+    """Toy dataset of 100 records whose labels cycle through classes 0-9.
+
+    Args:
+        class_prob (dict | None): Mapping from class index to a per-class
+            sampling weight, read by ``ClassSpecificDistributedSampler``.
+            Defaults to a uniform mapping ``{0: 1, ..., 9: 1}``.
+    """
+
+    def __init__(self, class_prob=None):
+        super().__init__()
+        # Build the default mapping per instance rather than using a
+        # mutable default argument, which would be shared across instances.
+        if class_prob is None:
+            class_prob = {i: 1 for i in range(10)}
+        self.class_prob = class_prob
+        self.video_infos = [
+            dict(data=idx, label=idx % 10) for idx in range(100)
+        ]
+
+    def __len__(self):
+        """Return the number of samples (always 100)."""
+        return len(self.video_infos)
+
+    def __getitem__(self, idx):
+        """Return the ``dict(data=..., label=...)`` record at ``idx``."""
+        return self.video_infos[idx]
+
+
+def test_distributed_sampler():
+    """Check that DistributedSampler partitions the dataset per replica."""
+    dataset = MyDataset()
+
+    # Single replica: every one of the 100 samples is visited exactly once,
+    # yielding 25 full batches of size 4.
+    sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 25
+    assert sum(len(x['data']) for x in batches) == 100
+
+    # 4 replicas: each rank receives 100 / 4 = 25 samples, i.e. 7 batches
+    # (6 of size 4 plus one of size 1).
+    sampler = DistributedSampler(dataset, num_replicas=4, rank=2)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 7
+    assert sum(len(x['data']) for x in batches) == 25
+
+    # 6 replicas: 100 does not divide evenly, so each rank is padded up to
+    # ceil(100 / 6) = 17 samples, i.e. 5 batches.
+    sampler = DistributedSampler(dataset, num_replicas=6, rank=3)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 5
+    assert sum(len(x['data']) for x in batches) == 17
+
+
+def test_class_specific_distributed_sampler():
+    """Check ClassSpecificDistributedSampler with per-class sample weights."""
+    # Classes 0-4 are weighted 1, classes 5-9 are weighted 3.
+    class_prob = dict(zip(list(range(10)), [1] * 5 + [3] * 5))
+    dataset = MyDataset(class_prob=class_prob)
+
+    # dynamic_length=True grows the epoch to the weighted total:
+    # 5 classes * 10 samples * 1 + 5 classes * 10 samples * 3 = 200.
+    sampler = ClassSpecificDistributedSampler(
+        dataset, num_replicas=1, rank=0, dynamic_length=True)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 50
+    assert sum(len(x['data']) for x in batches) == 200
+
+    # dynamic_length=False keeps the original epoch length of 100.
+    sampler = ClassSpecificDistributedSampler(
+        dataset, num_replicas=1, rank=0, dynamic_length=False)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 25
+    assert sum(len(x['data']) for x in batches) == 100
+
+    # 6 replicas, dynamic length: each rank gets ceil(200 / 6) = 34 samples,
+    # i.e. 9 batches (8 of size 4 plus one of size 2).
+    sampler = ClassSpecificDistributedSampler(
+        dataset, num_replicas=6, rank=2, dynamic_length=True)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 9
+    assert sum(len(x['data']) for x in batches) == 34
+
+    # 6 replicas, fixed length: each rank gets ceil(100 / 6) = 17 samples,
+    # i.e. 5 batches.
+    sampler = ClassSpecificDistributedSampler(
+        dataset, num_replicas=6, rank=2, dynamic_length=False)
+    data_loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+    batches = list(data_loader)
+    assert len(batches) == 5
+    assert sum(len(x['data']) for x in batches) == 17