.dev/gather_models.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import hashlib
import json
import os
import os.path as osp
import shutil

import mmcv
import torch

# look-up table of result keys used to find the final model's performance
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']


def calculate_file_sha256(file_path):
    """Calculate the SHA-256 hash of a file."""
    with open(file_path, 'rb') as fp:
        sha256_cal = hashlib.sha256()
        sha256_cal.update(fp.read())
        return sha256_cal.hexdigest()


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove the optimizer state for a smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # The hash calculation and rename commands differ across platforms.
    sha = calculate_file_sha256(out_file)
    # drop a trailing '.pth' suffix if present (str.rstrip would remove a
    # character set, not the suffix) and append the short hash
    out_file_prefix = out_file[:-4] if out_file.endswith('.pth') else out_file
    final_file = out_file_prefix + '-{}.pth'.format(sha[:8])
    os.rename(out_file, final_file)

    # remove the directory prefix and the file extension
    final_file_name = osp.split(final_file)[1]
    final_file_name = osp.splitext(final_file_name)[0]

    return final_file_name


def get_final_iter(config):
    iter_num = config.split('_')[-2]
    assert iter_num.endswith('k')
    return int(iter_num[:-1]) * 1000


def get_final_results(log_json_path, iter_num):
    result_dict = dict()
    last_iter = 0
    with open(log_json_path, 'r') as f:
        for line in f.readlines():
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            # During evaluation, the 'iter' of the new-style log json is the
            # evaluation step count on a single GPU.
            flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
            flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
            if flag1 and flag2:
                result_dict.update({
                    key: log_line[key]
                    for key in RESULTS_LUT if key in log_line
                })
                return result_dict

            last_iter = log_line['iter']


def parse_args():
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        '-f', '--config-name', type=str, help='Process the selected config.')
    parser.add_argument(
        '-w',
        '--work-dir',
        default='work_dirs/',
        type=str,
        help='Ckpt storage root folder of benchmarked models to be gathered.')
    parser.add_argument(
        '-c',
        '--collect-dir',
        default='work_dirs/gather',
        type=str,
        help='Ckpt collect root folder of gathered models.')
    parser.add_argument(
        '--all', action='store_true', help='whether to include .py and .log')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    work_dir = args.work_dir
    collect_dir = args.collect_dir
    selected_config_name = args.config_name
    mmcv.mkdir_or_exist(collect_dir)

    # find all models in the root directory to be gathered
    raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        config_name = osp.splitext(osp.basename(raw_config))[0]
        if osp.exists(osp.join(work_dir, config_name)):
            if (selected_config_name is None
                    or selected_config_name == config_name):
                used_configs.append(raw_config)
    print(f'Find {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file for each trained config
    # and parse the best performance
    model_infos = []
    for used_config in used_configs:
        config_name = osp.splitext(osp.basename(used_config))[0]
        exp_dir = osp.join(work_dir, config_name)
        # check whether the experiment is finished
        final_iter = get_final_iter(used_config)
        final_model = 'iter_{}.pth'.format(final_iter)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            print(f'{used_config} train not finished yet')
            continue

        # get logs
        log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
        log_json_path = log_json_paths[0]
        model_performance = None
        for idx, _log_json_path in enumerate(log_json_paths):
            model_performance = get_final_results(_log_json_path, final_iter)
            if model_performance is not None:
                log_json_path = _log_json_path
                break

        if model_performance is None:
            print(f'{used_config} model_performance is None')
            continue

        model_time = osp.split(log_json_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config_name=config_name,
                results=model_performance,
                iters=final_iter,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # publish model for each checkpoint
    publish_model_infos = []
    for model in model_infos:
        config_name = model['config_name']
        model_publish_dir = osp.join(collect_dir, config_name)

        publish_model_path = osp.join(model_publish_dir,
                                      config_name + '_' + model['model_time'])
        trained_model_path = osp.join(work_dir, config_name,
                                      'iter_{}.pth'.format(model['iters']))
        if osp.exists(model_publish_dir):
            for file in os.listdir(model_publish_dir):
                if file.endswith('.pth'):
                    print(f'model {file} found')
                    model['model_path'] = osp.abspath(
                        osp.join(model_publish_dir, file))
                    break
            if 'model_path' not in model:
                print(f'dir {model_publish_dir} exists, no model found')

        else:
            mmcv.mkdir_or_exist(model_publish_dir)

            # convert model
            final_model_path = process_checkpoint(trained_model_path,
                                                  publish_model_path)
            model['model_path'] = final_model_path

        new_json_path = f'{config_name}_{model["log_json_path"]}'
        # copy log
        shutil.copy(
            osp.join(work_dir, config_name, model['log_json_path']),
            osp.join(model_publish_dir, new_json_path))

        if args.all:
            # slice off the '.json' suffix to get the plain-text log name
            # (str.rstrip('.json') would strip a character set, not a suffix)
            new_txt_path = new_json_path[:-len('.json')]
            shutil.copy(
                osp.join(work_dir, config_name,
                         model['log_json_path'][:-len('.json')]),
                osp.join(model_publish_dir, new_txt_path))

        if args.all:
            # copy config to guarantee reproducibility
            raw_config = osp.join('./configs', f'{config_name}.py')
            mmcv.Config.fromfile(raw_config).dump(
                osp.join(model_publish_dir, osp.basename(raw_config)))

        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    mmcv.dump(models, osp.join(collect_dir, 'model_infos.json'), indent=4)


if __name__ == '__main__':
    main()