--- a +++ b/tests/test_data/test_compose.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +from mmcv.utils import assert_keys_equal, digit_version + +from mmaction.datasets.pipelines import Compose, ImageToTensor + +try: + import torchvision + torchvision_ok = False + if digit_version(torchvision.__version__) >= digit_version('0.8.0'): + torchvision_ok = True +except (ImportError, ModuleNotFoundError): + torchvision_ok = False + + +def test_compose(): + with pytest.raises(TypeError): + # transform must be callable or a dict + Compose('LoadImage') + + target_keys = ['img', 'img_metas'] + + # test Compose given a data pipeline + img = np.random.randn(256, 256, 3) + results = dict(img=img, abandoned_key=None, img_name='test_image.png') + test_pipeline = [ + dict(type='Collect', keys=['img'], meta_keys=['img_name']), + dict(type='ImageToTensor', keys=['img']) + ] + compose = Compose(test_pipeline) + compose_results = compose(results) + assert assert_keys_equal(compose_results.keys(), target_keys) + assert assert_keys_equal(compose_results['img_metas'].data.keys(), + ['img_name']) + + # test Compose when forward data is None + results = None + image_to_tensor = ImageToTensor(keys=[]) + test_pipeline = [image_to_tensor] + compose = Compose(test_pipeline) + compose_results = compose(results) + assert compose_results is None + + assert repr(compose) == compose.__class__.__name__ + \ + f'(\n {image_to_tensor}\n)' + + +@pytest.mark.skipif( + not torchvision_ok, reason='torchvision >= 0.8.0 is required') +def test_compose_support_torchvision(): + target_keys = ['imgs', 'img_metas'] + + # test Compose given a data pipeline + imgs = [np.random.randn(256, 256, 3)] * 8 + results = dict( + imgs=imgs, + abandoned_key=None, + img_name='test_image.png', + clip_len=8, + num_clips=1) + test_pipeline = [ + dict(type='torchvision.Grayscale', num_output_channels=3), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='Collect', keys=['imgs'], meta_keys=['img_name']), + dict(type='ToTensor', keys=['imgs']) + ] + compose = Compose(test_pipeline) + compose_results = compose(results) + assert assert_keys_equal(compose_results.keys(), target_keys) + assert assert_keys_equal(compose_results['img_metas'].data.keys(), + ['img_name'])