.dev/gather_models.py
# Copyright (c) OpenMMLab. All rights reserved.
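"""Gather trained checkpoints and logs of benchmarked models.

The script scans the work dir for finished experiments, strips the
optimizer state from each final checkpoint, appends the first 8
characters of the checkpoint's sha256 hash to its file name, and
collects checkpoints, logs and (optionally) configs into a collect dir
together with a ``model_infos.json`` summary.

Example invocation using the default paths defined in ``parse_args``
(adjust paths as needed)::

    python .dev/gather_models.py -w work_dirs/ -c work_dirs/gather --all
"""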
import argparse
import glob
import hashlib
import json
import os
import os.path as osp
import shutil

import mmcv
import torch

# look-up table of result keys used to parse the final model's performance
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']


def calculate_file_sha256(file_path):
    """Calculate the sha256 hash of a file, reading it in chunks."""
    sha256_cal = hashlib.sha256()
    with open(file_path, 'rb') as fp:
        # 1 MiB chunks keep memory usage flat for multi-GB checkpoints
        for chunk in iter(lambda: fp.read(1 << 20), b''):
            sha256_cal.update(chunk)
    return sha256_cal.hexdigest()


def process_checkpoint(in_file, out_file):
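    """Strip training-only state from a checkpoint and publish it.

    Removes the optimizer state, saves the result to ``out_file`` and
    renames the file with the first 8 characters of its sha256 hash,
    e.g. a (hypothetical) ``model.pth`` becomes ``model-abcd1234.pth``.
    """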
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove the optimizer state for a smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # compute the hash in Python because the hash and rename shell commands
    # differ across platforms
    sha = calculate_file_sha256(out_file)
    # note: str.rstrip strips a character set, not a suffix, so use
    # osp.splitext to drop a '.pth' extension safely
    final_file = osp.splitext(out_file)[0] + '-{}.pth'.format(sha[:8])
    os.rename(out_file, final_file)

    # remove the directory prefix and the extension
    final_file_name = osp.split(final_file)[1]
    final_file_name = osp.splitext(final_file_name)[0]

    return final_file_name


def get_final_iter(config):
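    """Parse the total number of training iterations from a config name.

    Assumes the second-to-last underscore-separated field encodes the
    schedule, e.g. a (hypothetical) 'fcn_r50_512x1024_80k_cityscapes'
    gives 80000.
    """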
    iter_num = config.split('_')[-2]
    assert iter_num.endswith('k')
    return int(iter_num[:-1]) * 1000


def get_final_results(log_json_path, iter_num):
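    """Collect the RESULTS_LUT metrics of the final evaluation.

    Each line of the log file is a json dict. The final results are taken
    from the first evaluation entry that directly follows the last
    training iteration (``iter_num``).
    """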
    result_dict = dict()
    last_iter = 0
    with open(log_json_path, 'r') as f:
        for line in f:
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            # in the new-style log json, the 'iter' of an evaluation entry
            # counts evaluation steps on a single GPU, so match against the
            # last recorded training iteration instead; the last training
            # entry is logged at iter_num - 50 (presumably the log interval)
            # or at iter_num itself
            flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
            flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
            if flag1 and flag2:
                result_dict.update({
                    key: log_line[key]
                    for key in RESULTS_LUT if key in log_line
                })
                return result_dict

            last_iter = log_line['iter']


def parse_args():
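    """Parse command line arguments."""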
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        '-f', '--config-name', type=str, help='Process the selected config.')
    parser.add_argument(
        '-w',
        '--work-dir',
        default='work_dirs/',
        type=str,
        help='Ckpt storage root folder of benchmarked models to be gathered.')
    parser.add_argument(
        '-c',
        '--collect-dir',
        default='work_dirs/gather',
        type=str,
        help='Ckpt collect root folder of gathered models.')
    parser.add_argument(
        '--all',
        action='store_true',
        help='whether to also gather the .py config and .log files')

    args = parser.parse_args()
    return args


def main():
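    """Gather checkpoints, logs and configs of all finished experiments."""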
    args = parse_args()
    work_dir = args.work_dir
    collect_dir = args.collect_dir
    selected_config_name = args.config_name
    mmcv.mkdir_or_exist(collect_dir)

    # find all configs whose models may need to be gathered
    raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        config_name = osp.splitext(osp.basename(raw_config))[0]
        if osp.exists(osp.join(work_dir, config_name)):
            if (selected_config_name is None
                    or selected_config_name == config_name):
                used_configs.append(raw_config)
    print(f'Found {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file of each trained config
    # and parse the final performance
    model_infos = []
    for used_config in used_configs:
        config_name = osp.splitext(osp.basename(used_config))[0]
        exp_dir = osp.join(work_dir, config_name)
        # check whether the experiment has finished
        final_iter = get_final_iter(used_config)
        final_model = 'iter_{}.pth'.format(final_iter)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            print(f'{used_config} train not finished yet')
            continue

        # find a log json that contains the final evaluation results
        log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
        if not log_json_paths:
            print(f'{used_config} has no log json file')
            continue
        log_json_path = log_json_paths[0]
        model_performance = None
        for _log_json_path in log_json_paths:
            model_performance = get_final_results(_log_json_path, final_iter)
            if model_performance is not None:
                log_json_path = _log_json_path
                break

        if model_performance is None:
            print(f'{used_config} model_performance is None')
            continue

        # the log json is named by its timestamp, so the part before the
        # first dot identifies the training run
        model_time = osp.split(log_json_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config_name=config_name,
                # keep the path relative to ./configs so the config can be
                # copied later even when it lives in a sub-folder
                raw_config=used_config,
                results=model_performance,
                iters=final_iter,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # publish a model for each gathered checkpoint
    publish_model_infos = []
    for model in model_infos:
        config_name = model['config_name']
        model_publish_dir = osp.join(collect_dir, config_name)

        publish_model_path = osp.join(model_publish_dir,
                                      config_name + '_' + model['model_time'])
        trained_model_path = osp.join(work_dir, config_name,
                                      'iter_{}.pth'.format(model['iters']))
        if osp.exists(model_publish_dir):
            # reuse a previously published checkpoint if one exists
            for file in os.listdir(model_publish_dir):
                if file.endswith('.pth'):
                    print(f'model {file} found')
                    model['model_path'] = osp.abspath(
                        osp.join(model_publish_dir, file))
                    break
            if 'model_path' not in model:
                print(f'dir {model_publish_dir} exists, no model found')

        else:
            mmcv.mkdir_or_exist(model_publish_dir)

            # convert the trained checkpoint into a publishable one
            final_model_path = process_checkpoint(trained_model_path,
                                                  publish_model_path)
            model['model_path'] = final_model_path

        new_json_path = f'{config_name}_{model["log_json_path"]}'
        # copy the json log
        shutil.copy(
            osp.join(work_dir, config_name, model['log_json_path']),
            osp.join(model_publish_dir, new_json_path))

        if args.all:
            # copy the plain-text log; use osp.splitext to drop only the
            # trailing '.json' (str.rstrip would strip characters instead)
            new_txt_path = osp.splitext(new_json_path)[0]
            shutil.copy(
                osp.join(work_dir, config_name,
                         osp.splitext(model['log_json_path'])[0]),
                osp.join(model_publish_dir, new_txt_path))

            # copy the config to guarantee reproducibility; use the stored
            # relative path since configs may live in sub-folders
            raw_config = osp.join('./configs', model['raw_config'])
            mmcv.Config.fromfile(raw_config).dump(
                osp.join(model_publish_dir, osp.basename(raw_config)))

        publish_model_infos.append(model)
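
    # write a json summary mapping 'models' to the list of per-model
    # dicts assembled above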
    models = dict(models=publish_model_infos)
    mmcv.dump(models, osp.join(collect_dir, 'model_infos.json'), indent=4)


if __name__ == '__main__':
    main()