|
a |
|
b/tools/publish_model.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import argparse |
|
|
3 |
import subprocess |
|
|
4 |
|
|
|
5 |
import torch |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def parse_args():
    """Parse the command line of the checkpoint-publishing tool.

    Two positional arguments are read from ``sys.argv``: the checkpoint to
    process and the path the processed checkpoint is written to.

    Returns:
        argparse.Namespace: namespace with ``in_file`` and ``out_file``.
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    positional = (
        ('in_file', 'input checkpoint filename'),
        ('out_file', 'output checkpoint filename'),
    )
    for dest, help_text in positional:
        cli.add_argument(dest, help=help_text)
    return cli.parse_args()
|
|
15 |
|
|
|
16 |
|
|
|
17 |
def process_checkpoint(in_file, out_file):
    """Strip training-only state from a checkpoint and publish it.

    The optimizer state is removed (for a smaller file), the result is saved
    to ``out_file``, and that file is then renamed to
    ``<stem>-<first 8 hex chars of its SHA-256>.pth``.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path the processed checkpoint is first written to.
            A trailing ``.pth`` (if present) is removed before the hash
            suffix and a new ``.pth`` extension are appended.

    Returns:
        str: Path of the final, hash-suffixed checkpoint file.
    """
    import hashlib
    import os

    # NOTE(review): torch.load unpickles the file, so only run this on
    # trusted checkpoints.
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    checkpoint.pop('optimizer', None)
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash in-process instead of shelling out to `sha256sum`, which is not
    # available on every platform. Read in chunks to bound memory use.
    digest = hashlib.sha256()
    with open(out_file, 'rb') as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b''):
            digest.update(chunk)

    # Remove only a literal '.pth' suffix. str.rstrip('.pth') strips any
    # trailing '.', 'p', 't', 'h' characters (e.g. 'depth.pth' -> 'de').
    suffix = '.pth'
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    final_file = '{}-{}{}'.format(stem, digest.hexdigest()[:8], suffix)

    # Rename synchronously and with error reporting. The previous
    # un-awaited `mv` subprocess could race with interpreter exit.
    os.replace(out_file, final_file)
    return final_file
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def main():
    """Script entry point: read CLI arguments and publish the checkpoint."""
    cli_args = parse_args()
    process_checkpoint(in_file=cli_args.in_file, out_file=cli_args.out_file)
|
|
33 |
|
|
|
34 |
|
|
|
35 |
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()