a b/tools/deployment/publish_model.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os
4
import platform
5
import subprocess
6
7
import torch
8
9
10
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments with two positional string
        attributes, ``in_file`` (checkpoint to read) and ``out_file``
        (where the stripped checkpoint is written before renaming).
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    positionals = (
        ('in_file', 'input checkpoint filename'),
        ('out_file', 'output checkpoint filename'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)
    return cli.parse_args()
17
18
19
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces so large checkpoints are
    never held in memory all at once. :mod:`hashlib` is used instead of
    shelling out to ``sha256sum`` / ``certutil`` so the digest is the same on
    every platform and does not depend on parsing external tool output.
    """
    import hashlib

    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def process_checkpoint(in_file, out_file):
    """Strip training-only state from a checkpoint and publish it.

    The checkpoint is loaded onto the CPU and its ``'optimizer'`` entry (if
    present) is removed to shrink the file. The result is saved to
    ``out_file``, which is then renamed to ``<stem>-<sha8>.pth``:

    - ``<stem>`` is ``out_file`` without a trailing ``.pth``.
    - ``<sha8>`` is the first eight hex characters of the saved file's
      SHA-256 digest.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Intermediate output path. It no longer exists after
            the call because it is renamed to the hash-suffixed name.

    Returns:
        str: Path of the published (renamed) checkpoint.
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    sha = _sha256_of_file(out_file)

    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows if e.g. the same checkpoint was already
    # published under this hash-suffixed name.
    os.replace(out_file, final_file)
    return final_file
39
40
41
def main():
    """Script entry point: read CLI arguments and publish the checkpoint."""
    cli_args = parse_args()
    process_checkpoint(in_file=cli_args.in_file, out_file=cli_args.out_file)
44
45
46
if __name__ == '__main__':  # run only when executed as a script
    main()