a b/ViTPose/tests/test_onnx.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import os.path as osp
3
import tempfile
4
5
import torch.nn as nn
6
7
from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx
8
9
10
class DummyModel(nn.Module):
    """Tiny Conv3d -> SyncBatchNorm network used as an ONNX export fixture.

    The ``SyncBatchNorm`` layer is here on purpose: the export test runs the
    model through ``_convert_batchnorm`` before exporting, so this fixture
    gives that conversion something to act on.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 2 output channels, kernel size 1 in every dim.
        # Keep the attribute names ``conv`` / ``bn``: they become the
        # submodule (and state-dict) names seen by conversion/export code.
        self.conv = nn.Conv3d(1, 2, 1)
        # num_features must match the conv's 2 output channels.
        self.bn = nn.SyncBatchNorm(2)

    def forward(self, x):
        """Convolve ``x`` and batch-normalise the result.

        Args:
            x: Input tensor for ``nn.Conv3d`` (5-D, with 1 channel).

        Returns:
            The normalised 2-channel feature tensor.
        """
        features = self.conv(x)
        normalized = self.bn(features)
        return normalized

    def forward_dummy(self, x):
        """Run :meth:`forward` and wrap its output in a 1-tuple.

        NOTE(review): presumably this is the entry point the export tool
        traces; confirm against ``tools.deployment.pytorch2onnx``.
        """
        outputs = (self.forward(x),)
        return outputs
22
23
24
def test_onnx_exporting():
    """Export ``DummyModel`` to ONNX and check that a file is produced.

    ``DummyModel`` contains a ``SyncBatchNorm`` layer, so the model is first
    passed through ``_convert_batchnorm``. The test then checks that the
    conversion replaced that layer and that the export actually wrote the
    requested output file. Previously the test only checked that no
    exception was raised.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'tmp.onnx')
        model = DummyModel()
        model = _convert_batchnorm(model)
        # The conversion step exists to get rid of SyncBatchNorm before
        # export; make sure it actually did.
        assert not isinstance(model.bn, nn.SyncBatchNorm)
        # test exporting; Conv3d needs a 5-D (N, C, D, H, W) input shape.
        pytorch2onnx(model, (1, 1, 1, 1, 1), output_file=out_file)
        # The export must leave an ONNX file at the requested path.
        assert osp.isfile(out_file)