# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import torch.nn as nn
from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx
class DummyModel(nn.Module):
    """Minimal Conv3d + SyncBatchNorm network for the ONNX export test.

    ``conv`` maps 1 input channel to 2 output channels with a 1x1x1 kernel,
    and ``bn`` normalises those 2 channels with ``nn.SyncBatchNorm``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv, then bn) fixes the child / state_dict order.
        self.conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=1)
        self.bn = nn.SyncBatchNorm(num_features=2)

    def forward(self, x):
        """Apply ``conv`` and then ``bn`` to ``x``."""
        features = self.conv(x)
        normalized = self.bn(features)
        return normalized

    def forward_dummy(self, x):
        """Run :meth:`forward` and wrap its result in a 1-tuple."""
        output = self.forward(x)
        return (output,)
def test_onnx_exporting():
    """Export a SyncBatchNorm-containing model to ONNX and check the output.

    The model is passed through ``_convert_batchnorm`` and then exported with
    ``pytorch2onnx`` into a temporary directory. The test passes only if the
    export raises nothing and a non-empty ONNX file appears at ``out_file``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'tmp.onnx')
        model = DummyModel()
        # NOTE(review): presumably swaps SyncBatchNorm for an exportable
        # BatchNorm layer -- confirm against tools.deployment.pytorch2onnx.
        model = _convert_batchnorm(model)
        # test exporting; the 5-D tuple is assumed to be the dummy input
        # shape, with a channel dim of 1 to match DummyModel's Conv3d(1, ...).
        pytorch2onnx(model, (1, 1, 1, 1, 1), output_file=out_file)
        # Without these checks the test would only prove export did not
        # raise. Verify before tmpdir is cleaned up on leaving the ``with``.
        assert osp.isfile(out_file), f'ONNX file was not written: {out_file}'
        assert osp.getsize(out_file) > 0, f'ONNX file is empty: {out_file}'