Switch to side-by-side view

--- a
+++ b/mmaction/datasets/pipelines/compose.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Sequence
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+from .augmentations import PytorchVideoTrans, TorchvisionTrans
+
+
+@PIPELINES.register_module()
+class Compose:
+    """Compose a data pipeline with a sequence of transforms.
+
+    Args:
+        transforms (list[dict | callable]):
+            Either config dicts of transforms or transform objects.
+    """
+
+    def __init__(self, transforms):
+        assert isinstance(transforms, Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                if transform['type'].startswith('torchvision.'):
+                    trans_type = transform.pop('type')[12:]
+                    transform = TorchvisionTrans(trans_type, **transform)
+                elif transform['type'].startswith('pytorchvideo.'):
+                    trans_type = transform.pop('type')[13:]
+                    transform = PytorchVideoTrans(trans_type, **transform)
+                else:
+                    transform = build_from_cfg(transform, PIPELINES)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError(f'transform must be callable or a dict, '
+                                f'but got {type(transform)}')
+
+    def __call__(self, data):
+        """Call function to apply transforms sequentially.
+
+        Args:
+            data (dict): A result dict contains the data to transform.
+
+        Returns:
+            dict: Transformed data.
+        """
+        #print(data)
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string