[6d389a]: / tools / deployment / mmaction2torchserve.py

Download this file

110 lines (95 with data), 3.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) OpenMMLab. All rights reserved.
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import mmcv
try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError as err:
    # Both helpers come from the `torch-model-archiver` package; re-raise
    # with an actionable install hint while keeping the original cause.
    raise ImportError('`torch-model-archiver` is required. '
                      'Try: pip install torch-model-archiver') from err


def mmaction2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: Optional[str],
    label_file: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts MMAction2 model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file (str): In MMAction2 config format.
        checkpoint_file (str): In MMAction2 checkpoint format.
        output_folder (str): Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name (str | None): If not None (and not empty), used for
            naming the `{model_name}.mar` file that will be created under
            `output_folder`. Otherwise `{Path(checkpoint_file).stem}` will
            be used.
        label_file (str): A txt file which contains the action category
            names. It is bundled into the archive as `label_map.txt`.
        model_version (str): Model's version. Defaults to '1.0'.
        force (bool): If True, an existing `{model_name}.mar` file under
            `output_folder` will be overwritten. Defaults to False.

    Raises:
        TypeError: If `label_file` is None.
    """
    if label_file is None:
        # The label map is always bundled into the archive (see
        # `extra_files` below). Reject a missing one up front with a
        # readable message instead of failing later inside `shutil.copy`,
        # after output folder and temp files were already created.
        raise TypeError('`label_file` is required to build the TorchServe '
                        'archive, but got None.')

    mmcv.mkdir_or_exist(output_folder)
    config = mmcv.Config.fromfile(config_file)
    # The request handler is looked up next to this script.
    handler_file = Path(__file__).parent / 'mmaction_handler.py'

    with TemporaryDirectory() as tmpdir:
        # Stage the re-dumped config and the label map under fixed file
        # names so they can be packed into the archive.
        tmp_path = Path(tmpdir)
        config_copy = tmp_path / 'config.py'
        label_copy = tmp_path / 'label_map.txt'
        config.dump(str(config_copy))
        shutil.copy(label_file, label_copy)

        # The archiver API takes its options as an argparse `Namespace`.
        args = Namespace(
            model_file=str(config_copy),
            serialized_file=checkpoint_file,
            handler=str(handler_file),
            model_name=model_name or Path(checkpoint_file).stem,
            version=model_version,
            export_path=output_folder,
            force=force,
            requirements_file=None,
            extra_files=str(label_copy),
            runtime='python',
            archive_format='default')
        # Packaging must happen inside the `with` block: the staged files
        # are removed when the temporary directory is cleaned up.
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args(argv=None):
    """Parse the command-line options of the converter.

    Args:
        argv (list[str] | None): Arguments to parse. If None, argparse falls
            back to `sys.argv[1:]`, which matches the previous behavior.

    Returns:
        argparse.Namespace: Parsed options with the attributes `config`,
        `checkpoint`, `output_folder`, `model_name`, `label_file`,
        `model_version` and `force`.
    """
    parser = ArgumentParser(
        description='Convert MMAction2 models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    # The converter always copies the label file into the archive, so it
    # cannot run without one; let argparse report the omission clearly.
    parser.add_argument(
        '--label-file',
        type=str,
        required=True,
        help='A txt file which contains the action category names.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: read the CLI options and hand them to the
    # converter by keyword, so the option-to-parameter mapping is explicit.
    args = parse_args()
    mmaction2torchserve(
        config_file=args.config,
        checkpoint_file=args.checkpoint,
        output_folder=args.output_folder,
        model_name=args.model_name,
        label_file=args.label_file,
        model_version=args.model_version,
        force=args.force,
    )