# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.utils import assert_dict_has_keys
from mmaction.datasets.pipelines import (Collect, FormatAudioShape,
FormatGCNInput, FormatShape,
ImageToTensor, Rename,
ToDataContainer, ToTensor, Transpose)
def test_rename():
    """Check that ``Rename`` moves a value from its old key to the new key."""
    # Map the original key 'a' onto the new key 'b'.
    key_mapping = {'a': 'b'}
    renamer = Rename(key_mapping)

    renamed = renamer(dict(a=2))

    # The value must be reachable under the new key only.
    assert renamed['b'] == 2
    assert 'a' not in renamed
def test_to_tensor():
    """Check ``ToTensor`` on unsupported input, supported inputs and repr."""
    # str cannot be converted to tensor
    str_to_tensor = ToTensor(['str'])
    with pytest.raises(TypeError):
        str_to_tensor(dict(str='0'))

    # convert tensor, numpy, sequence, int, float to tensor
    target_keys = ['tensor', 'numpy', 'sequence', 'int', 'float']
    to_tensor = ToTensor(target_keys)

    def _fresh_inputs(**extra):
        # Build a new input dict (with new random data) for every case.
        return dict(
            tensor=torch.randn(2, 3),
            numpy=np.random.randn(2, 3),
            sequence=list(range(10)),
            int=1,
            float=0.1,
            **extra)

    def _check_conversion(source):
        converted = to_tensor(source)
        assert assert_dict_has_keys(converted, target_keys)
        for name in target_keys:
            value = converted[name]
            assert isinstance(value, torch.Tensor)
            # NOTE(review): if ToTensor converts in place and returns the
            # same dict, `source[name]` is already the converted tensor and
            # this comparison is trivially true -- confirm.
            assert torch.equal(value.data, source[name])

    # Only the target keys are present.
    _check_conversion(_fresh_inputs())

    # Add an additional key which is not in keys.
    _check_conversion(_fresh_inputs(str='test'))

    expected_repr = to_tensor.__class__.__name__ + f'(keys={target_keys})'
    assert repr(to_tensor) == expected_repr
def test_to_data_container():
    """Check ``ToDataContainer`` with user-defined fields and its repr."""
    # check user-defined fields
    fields = (dict(key='key1', stack=True), dict(key='key2'))
    to_data_container = ToDataContainer(fields=fields)
    target_keys = ['key1', 'key2']

    def _check_wrapping(source):
        # Hand over a shallow copy so `source` keeps the raw values to
        # compare against.
        wrapped = to_data_container(source.copy())
        assert assert_dict_has_keys(wrapped, target_keys)
        for name in target_keys:
            container = wrapped[name]
            assert isinstance(container, DC)
            assert np.all(container.data == source[name])
        # 'key1' was declared with stack=True, 'key2' without it.
        assert wrapped['key1'].stack
        assert not wrapped['key2'].stack

    _check_wrapping(dict(key1=np.random.randn(10, 20), key2=['a', 'b']))

    # Add an additional key which is not in keys.
    _check_wrapping(
        dict(key1=np.random.randn(10, 20), key2=['a', 'b'], key3='value3'))

    expected_repr = (
        to_data_container.__class__.__name__ + f'(fields={fields})')
    assert repr(to_data_container) == expected_repr
def test_image_to_tensor():
    """Check ``ImageToTensor`` output shape, output type and repr."""
    keys = ['imgs']
    image_to_tensor = ImageToTensor(keys)

    # Input array is created with shape (256, 256, 3).
    original_results = dict(imgs=np.random.randn(256, 256, 3))
    results = image_to_tensor(original_results)

    converted = results['imgs']
    assert converted.shape == torch.Size([3, 256, 256])
    assert isinstance(converted, torch.Tensor)
    # NOTE(review): if ImageToTensor converts in place and returns the same
    # dict, this compares the tensor with itself -- confirm.
    assert torch.equal(converted.data, original_results['imgs'])

    expected_repr = image_to_tensor.__class__.__name__ + f'(keys={keys})'
    assert repr(image_to_tensor) == expected_repr
def test_transpose():
    """Check ``Transpose`` reorders the axes of the listed keys and its repr."""
    keys = ['imgs']
    order = [2, 0, 1]
    transpose = Transpose(keys, order)

    # Input array is created with shape (256, 256, 3).
    transposed = transpose(dict(imgs=np.random.randn(256, 256, 3)))
    assert transposed['imgs'].shape == (3, 256, 256)

    expected_repr = (
        transpose.__class__.__name__ + f'(keys={keys}, order={order})')
    assert repr(transpose) == expected_repr
def test_collect():
    """Check ``Collect`` output keys, meta contents, repr and nested mode."""
    inputs = dict(
        imgs=np.random.randn(256, 256, 3),
        label=[1],
        filename='test.txt',
        original_shape=(256, 256, 3),
        img_shape=(256, 256, 3),
        pad_shape=(256, 256, 3),
        flip_direction='vertical',
        img_norm_cfg=dict(to_bgr=False))
    keys = ['imgs', 'label']
    expected_keys = sorted(['imgs', 'label', 'img_metas'])

    collect = Collect(keys)
    results = collect(inputs)
    assert sorted(results.keys()) == expected_keys

    # Take 'imgs' out of the inputs for the meta comparison; it is put back
    # before the nested case below.
    imgs = inputs.pop('imgs')
    meta = results['img_metas'].data
    assert set(meta) == set(inputs)
    for meta_key in meta:
        assert meta[meta_key] == inputs[meta_key]

    expected_repr = collect.__class__.__name__ + \
        (f'(keys={keys}, meta_keys={collect.meta_keys}, '
         f'nested={collect.nested})')
    assert repr(collect) == expected_repr

    # With nested=True every collected value is expected to be a list.
    inputs['imgs'] = imgs
    nested_collect = Collect(keys, nested=True)
    nested_results = nested_collect(inputs)
    assert sorted(nested_results.keys()) == expected_keys
    assert all(isinstance(value, list) for value in nested_results.values())
def test_format_shape():
    """Check ``FormatShape`` for each supported input format and its repr."""
    # invalid input format
    with pytest.raises(ValueError):
        FormatShape('NHWC')

    def _make_results(num_imgs, **meta):
        # Random images of shape (num_imgs, 224, 224, 3) plus clip metadata.
        return dict(imgs=np.random.randn(num_imgs, 224, 224, 3), **meta)

    # 'NCHW' input format
    nchw_format = FormatShape('NCHW')
    nchw_results = _make_results(3, num_clips=1, clip_len=3)
    assert nchw_format(nchw_results)['input_shape'] == (3, 3, 224, 224)

    # `NCTHW` input format with num_clips=1, clip_len=3
    format_shape = FormatShape('NCTHW')
    single_clip = _make_results(3, num_clips=1, clip_len=3)
    assert format_shape(single_clip)['input_shape'] == (1, 3, 3, 224, 224)

    # `NCTHW` input format with num_clips=2, clip_len=3
    two_clips = _make_results(18, num_clips=2, clip_len=3)
    assert format_shape(two_clips)['input_shape'] == (6, 3, 3, 224, 224)
    # The key check runs on the dict that was passed in, not on the
    # returned value.
    target_keys = ['imgs', 'input_shape']
    assert assert_dict_has_keys(two_clips, target_keys)

    expected_repr = format_shape.__class__.__name__ + \
        "(input_format='NCTHW')"
    assert repr(format_shape) == expected_repr

    # 'NPTCHW' input format
    nptchw_format = FormatShape('NPTCHW')
    proposal_results = _make_results(
        72, num_clips=9, clip_len=1, num_proposals=8)
    assert nptchw_format(proposal_results)['input_shape'] == (
        8, 9, 3, 224, 224)
def test_format_audio_shape():
    """Check ``FormatAudioShape`` validation, output shape and repr."""
    # invalid input format
    with pytest.raises(ValueError):
        FormatAudioShape('XXXX')

    # 'NCTF' input format
    format_shape = FormatAudioShape('NCTF')
    audio_results = dict(audios=np.random.randn(3, 128, 8))
    formatted = format_shape(audio_results)
    assert formatted['input_shape'] == (3, 1, 128, 8)

    expected_repr = format_shape.__class__.__name__ + \
        "(input_format='NCTF')"
    assert repr(format_shape) == expected_repr
def test_format_gcn_input():
    """Check ``FormatGCNInput`` validation, output shape and repr."""
    # invalid input format
    with pytest.raises(ValueError):
        FormatGCNInput('XXXX')

    # 'NCTVM' input format
    format_shape = FormatGCNInput('NCTVM', num_person=2)
    expected_repr = format_shape.__class__.__name__ + \
        "(input_format='NCTVM')"

    # Real number of persons in the input: equal to, fewer than and more
    # than the configured num_person=2. The same output shape is expected
    # in every case.
    for real_num_person in (2, 1, 3):
        results = dict(
            keypoint=np.random.randn(real_num_person, 300, 17, 2),
            keypoint_score=np.random.randn(real_num_person, 300, 17))
        assert format_shape(results)['input_shape'] == (3, 300, 17, 2)
        assert repr(format_shape) == expected_repr