|
a |
|
b/tools/deployment/publish_model.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import argparse |
|
|
3 |
import os |
|
|
4 |
import platform |
|
|
5 |
import subprocess |
|
|
6 |
|
|
|
7 |
import torch |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def parse_args():
    """Parse the command line for checkpoint publishing.

    Returns:
        argparse.Namespace: Holds two positional values, ``in_file`` (path of
        the checkpoint to read) and ``out_file`` (path to write the processed
        checkpoint to).
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    positional = (
        ('in_file', 'input checkpoint filename'),
        ('out_file', 'output checkpoint filename'),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def process_checkpoint(in_file, out_file):
    """Strip training-only state from a checkpoint and stamp it with a hash.

    The checkpoint at ``in_file`` is loaded onto the CPU. Its ``optimizer``
    entry, if present, is dropped to reduce file size. The result is saved to
    ``out_file``, and that file is then renamed to
    ``<out_file without a trailing .pth>-<first 8 hex chars of SHA-256>.pth``.

    Args:
        in_file (str): Path of the checkpoint to publish.
        out_file (str): Path the processed checkpoint is first saved to,
            before it is renamed with the hash suffix.

    Returns:
        str: Path of the final, hash-stamped checkpoint file.
    """
    import hashlib

    # NOTE: torch.load unpickles ``in_file``; only run this on trusted
    # checkpoints.
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    checkpoint.pop('optimizer', None)
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash in-process instead of shelling out to ``sha256sum`` / ``certutil``.
    # Those tools are not present on every platform (stock macOS has no
    # ``sha256sum``), and certutil's output format differs between Windows
    # versions. hexdigest() yields the same lowercase hex those tools print.
    hasher = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            hasher.update(chunk)
    sha = hasher.hexdigest()

    out_file_name = out_file[:-4] if out_file.endswith('.pth') else out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename raises FileExistsError on Windows.
    os.replace(out_file, final_file)
    return final_file
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def main():
    """Script entry point: read CLI arguments and publish the checkpoint."""
    cli_args = parse_args()
    source_path = cli_args.in_file
    target_path = cli_args.out_file
    process_checkpoint(source_path, target_path)
|
|
44 |
|
|
|
45 |
|
|
|
46 |
# Run only when executed as a script (not on import), e.g.
#   python tools/deployment/publish_model.py IN_FILE OUT_FILE
if __name__ == '__main__':
    main()