--- a
+++ b/mmaction/datasets/dataset_wrappers.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from .builder import DATASETS, build_dataset
+
+
+@DATASETS.register_module()
+class RepeatDataset:
+    """A wrapper of repeated dataset.
+
+    The length of repeated dataset will be ``times`` larger than the original
+    dataset. This is useful when the data loading time is long but the dataset
+    is small. Using RepeatDataset can reduce the data loading time between
+    epochs.
+
+    Args:
+        dataset (dict): The config of the dataset to be repeated.
+        times (int): Repeat times.
+        test_mode (bool): Store True when building test or validation dataset.
+            Default: False.
+    """
+
+    def __init__(self, dataset, times, test_mode=False):
+        dataset['test_mode'] = test_mode
+        self.dataset = build_dataset(dataset)
+        self.times = times
+
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx):
+        """Get data."""
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        """Length after repetition."""
+        return self.times * self._ori_len
+
+
+@DATASETS.register_module()
+class ConcatDataset:
+    """A wrapper of concatenated dataset.
+
+    The length of concatenated dataset will be the sum of lengths of all
+    datasets. This is useful when you want to train a model with multiple data
+    sources.
+
+    Args:
+        datasets (list[dict]): The configs of the datasets.
+        test_mode (bool): Store True when building test or validation dataset.
+            Default: False.
+    """
+
+    def __init__(self, datasets, test_mode=False):
+
+        for item in datasets:
+            item['test_mode'] = test_mode
+
+        datasets = [build_dataset(cfg) for cfg in datasets]
+        self.datasets = datasets
+        self.lens = [len(x) for x in self.datasets]
+        self.cumsum = np.cumsum(self.lens)
+
+    def __getitem__(self, idx):
+        """Get data."""
+        dataset_idx = np.searchsorted(self.cumsum, idx, side='right')
+        item_idx = idx if dataset_idx == 0 else idx - self.cumsum[dataset_idx]
+        return self.datasets[dataset_idx][item_idx]
+
+    def __len__(self):
+        """Length after repetition."""
+        return sum(self.lens)