# Copyright (c) OpenMMLab. All rights reserved.
import sys
import warnings
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# TODO import test functions from mmcv and delete them from mmaction2
# Prefer the implementations shipped with MMCV: if ``mmcv.engine`` provides
# them, this whole test module is skipped. Otherwise fall back to the copies
# in ``mmaction.apis.test`` and warn that those are going away.
try:
    from mmcv.engine import (collect_results_cpu, multi_gpu_test,
                             single_gpu_test)
    # ``pytest.skip`` raises pytest's own skip exception, not an ImportError,
    # so it propagates past the ``except`` clause below and skips the module.
    pytest.skip(
        'Test functions are supported in MMCV', allow_module_level=True)
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    warnings.warn(
        'DeprecationWarning: single_gpu_test, multi_gpu_test, '
        'collect_results_cpu, collect_results_gpu from mmaction2 will be '
        'deprecated. Please install mmcv through master branch.')
    from mmaction.apis.test import (collect_results_cpu, multi_gpu_test,
                                    single_gpu_test)
class OldStyleModel(nn.Module):
    """Minimal ``nn.Module`` whose forward pass reports how often it ran.

    The module owns a 1x1 ``Conv2d`` that ``forward`` never touches
    (presumably just so the module has registered parameters). ``forward``
    ignores all of its inputs and returns a one-element list holding the
    number of earlier calls, so successive calls yield ``[0]``, ``[1]``,
    ``[2]``, ...
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.cnt = 0

    def forward(self, *args, **kwargs):
        calls_so_far = self.cnt
        self.cnt = calls_so_far + 1
        return [calls_so_far]
class Model(OldStyleModel):
    """``OldStyleModel`` extended with empty ``train_step``/``val_step`` hooks.

    Both hooks do nothing and return ``None``. They presumably exist so code
    that looks for these method names treats this as a "new style" model;
    that is not verified in this file.
    """

    def train_step(self):
        """Do nothing (training-step placeholder)."""

    def val_step(self):
        """Do nothing (validation-step placeholder)."""
class ExampleDataset(Dataset):
    """Toy dataset of eight identical samples.

    Only the length of ``eval_result`` is used here: it defines the dataset
    size. Every index yields the same sample, ``{'imgs': tensor([1])}``.
    """

    def __init__(self):
        self.index = 0  # not read anywhere in this class
        self.eval_result = [1, 4, 3, 7, 2, -3, 4, 6]

    def __len__(self):
        return len(self.eval_result)

    def __getitem__(self, idx):
        # ``idx`` is unused: all samples are identical.
        return {'imgs': torch.tensor([1])}
def test_single_gpu_test():
    """Run ``single_gpu_test`` over eight single-sample batches.

    ``Model`` returns ``[call_count]`` from each forward pass, so the test
    expects the collected results to be ``[0, 1, ..., 7]``.
    """
    data_loader = DataLoader(ExampleDataset(), batch_size=1)
    outputs = single_gpu_test(Model(), data_loader)
    assert outputs == list(range(8))
def mock_tensor_without_cuda(*args, **kwargs):
    """CPU-only stand-in for ``torch.tensor`` used when CUDA is unavailable.

    * If a ``device`` keyword is given, its value is ignored and an int32
      CPU tensor is built from ``args`` via ``torch.IntTensor``.
    * Otherwise a tensor of the default dtype is built with
      ``torch.Tensor(*args)``.

    On both paths every other keyword argument (``dtype`` included) is
    discarded.
    """
    if 'device' in kwargs:
        return torch.IntTensor(*args, device='cpu')
    return torch.Tensor(*args)
@patch('mmaction.apis.test.collect_results_gpu',
       Mock(return_value=list(range(8))))
@patch('mmaction.apis.test.collect_results_cpu',
       Mock(return_value=list(range(8))))
def test_multi_gpu_test():
    """Run ``multi_gpu_test`` with both result-collection helpers mocked.

    ``collect_results_gpu`` and ``collect_results_cpu`` in
    ``mmaction.apis.test`` are patched to return ``[0, ..., 7]``.
    ``multi_gpu_test`` is called twice with the same model and loader:
    once with its default ``gpu_collect`` and once with
    ``gpu_collect=False``. Both calls must return ``[0, ..., 7]``.
    """
    expected = list(range(8))
    data_loader = DataLoader(ExampleDataset(), batch_size=1)
    net = Model()
    for collect_kwargs in ({}, {'gpu_collect': False}):
        assert multi_gpu_test(net, data_loader, **collect_kwargs) == expected
@patch('mmcv.runner.get_dist_info', Mock(return_value=(0, 1)))
@patch('torch.distributed.broadcast', MagicMock())
@patch('torch.distributed.barrier', Mock())
@pytest.mark.skipif(
    sys.version_info[:2] == (3, 8), reason='Not for python 3.8')
def test_collect_results_cpu():
    """Check ``collect_results_cpu`` on a single, non-distributed process.

    ``torch.distributed.broadcast`` and ``torch.distributed.barrier`` are
    replaced with mock *instances*, so no process group is needed. Passing
    the ``Mock``/``MagicMock`` class itself (as earlier versions did) makes
    every call construct a fresh mock. The call's positional arguments are
    then taken as ``spec``/``side_effect``, so a broadcast tensor becomes a
    mock spec and gets introspected.

    ``collect_results_cpu`` is exercised both without ``tmpdir`` and with
    ``tmpdir='unittest'``. In both cases the collected result must equal
    the local ``results_part``.

    Without CUDA, two extra patches apply. ``torch.full`` is patched to
    return a pre-built 512-element uint8 CPU tensor filled with 32.
    ``torch.tensor`` is routed through ``mock_tensor_without_cuda``.
    """
    # NOTE(review): if ``mmaction.apis.test`` binds ``get_dist_info`` via
    # ``from mmcv.runner import get_dist_info``, patching
    # ``mmcv.runner.get_dist_info`` does not reach it. The test would then
    # rely on the real helper reporting (0, 1) outside a process group.
    # NOTE(review): the Python 3.8 skip may have been a workaround for the
    # class-as-replacement mocks fixed above; confirm before removing it.

    def content_for_unittest():
        results_part = list(range(8))
        size = 8
        results = collect_results_cpu(results_part, size)
        assert results == list(range(8))
        results = collect_results_cpu(results_part, size, 'unittest')
        assert results == list(range(8))

    if torch.cuda.is_available():
        content_for_unittest()
        return

    # Built before ``torch.full`` is patched; the patched ``torch.full``
    # returns this same CPU tensor on every call.
    cpu_full = torch.full((512, ), 32, dtype=torch.uint8, device='cpu')
    with patch('torch.full', Mock(return_value=cpu_full)):
        with patch('torch.tensor', mock_tensor_without_cuda):
            content_for_unittest()