tools/pytorch2torchscript.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmcv
import numpy as np
import torch
import torch._C
import torch.serialization
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.models import build_segmentor

torch.manual_seed(3)


def digit_version(version_str):
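    """Convert a version string into a list of ints usable for comparison.

    Illustrative examples (not executed): '1.8.0' -> [1, 8, 0] and
    '1.9.0rc1' -> [1, 9, -1, 1], so a release candidate compares as older
    than the corresponding final release.
    """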
    digit_version = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version


def check_torch_version():
    torch_minimum_version = '1.8.0'
    torch_version = digit_version(torch.__version__)

    assert (torch_version >= digit_version(torch_minimum_version)), \
        f'Torch=={torch.__version__} is not supported for converting to ' \
        f'torchscript. Please install pytorch>={torch_minimum_version}.'


def _convert_batchnorm(module):
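    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    The affine parameters and running statistics are copied over, so the
    converted model behaves the same at inference time while no longer
    depending on SyncBatchNorm (which is meant for distributed training).
    """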
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2libtorch(model,
                     input_shape,
                     show=False,
                     output_file='tmp.pt',
                     verify=False):
    """Export a PyTorch model to TorchScript and optionally verify that the
    outputs of the two models are the same.

    Args:
        model (nn.Module): PyTorch model to export.
        input_shape (tuple): Shape used to construct the dummy input that is
            fed through the model during tracing.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): The path where the output TorchScript model is
            stored. Default: `tmp.pt`.
        verify (bool): Whether to compare the outputs of PyTorch and
            TorchScript. Default: False.
    """
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')

    # replace the original forward with forward_dummy
    model.forward = model.forward_dummy
    model.eval()
    traced_model = torch.jit.trace(
        model,
        example_inputs=imgs,
        check_trace=verify,
    )

    if show:
        print(traced_model.graph)

    traced_model.save(output_file)
    print('Successfully exported TorchScript model: {}'.format(output_file))
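    # Note (illustrative, not executed here): the saved file is self-contained
    # and can later be reloaded with `torch.jit.load(output_file)` in Python
    # (or `torch::jit::load` in C++) and called on a float tensor of shape
    # `input_shape`.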


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMSeg to TorchScript')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--show', action='store_true', help='show TorchScript graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the TorchScript model')
    parser.add_argument('--output-file', type=str, default='tmp.pt')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 512],
        help='input image size (height, width)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    check_torch_version()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')

    # convert the PyTorch model to a TorchScript model
    pytorch2libtorch(
        segmentor,
        input_shape,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
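
# Example invocation (config and checkpoint paths are placeholders):
#   python tools/pytorch2torchscript.py configs/<model>/<config>.py \
#       --checkpoint <checkpoint>.pth --output-file model.pt \
#       --shape 512 512 --verify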