a b/.dev/gather_benchmark_train_results.py
1
import argparse
2
import glob
3
import os.path as osp
4
5
import mmcv
6
from gather_models import get_final_results
7
from mmcv import Config
8
9
10
def parse_args():
    """Read this script's command-line options.

    Positional ``config`` and ``root`` are required; ``--out`` is optional
    and defaults to ``benchmark_train_info.json``.

    Returns:
        argparse.Namespace: Parsed options with attributes ``config``,
        ``root`` and ``out``.
    """
    # (positional flags, keyword options) for each argument, in CLI order.
    arg_specs = (
        (('config', ), dict(help='test config file path')),
        (('root', ),
         dict(
             type=str,
             help='root path of benchmarked models to be gathered')),
        (('--out', ),
         dict(
             type=str,
             default='benchmark_train_info.json',
             help='output path of gathered metrics to be stored')),
    )
    arg_parser = argparse.ArgumentParser(
        description='Gather benchmarked models train results')
    for flags, options in arg_specs:
        arg_parser.add_argument(*flags, **options)
    return arg_parser.parse_args()
26
27
28
if __name__ == '__main__':
29
    args = parse_args()
30
31
    root_path = args.root
32
    metrics_out = args.out
33
34
    evaluation_cfg = Config.fromfile(args.config)
35
36
    result_dict = {}
37
    for model_key in evaluation_cfg:
38
        model_infos = evaluation_cfg[model_key]
39
        if not isinstance(model_infos, list):
40
            model_infos = [model_infos]
41
        for model_info in model_infos:
42
            config = model_info['config']
43
44
            # benchmark train dir
45
            model_name = osp.split(osp.dirname(config))[1]
46
            config_name = osp.splitext(osp.basename(config))[0]
47
            exp_dir = osp.join(root_path, model_name, config_name)
48
            if not osp.exists(exp_dir):
49
                print(f'{config} hasn\'t {exp_dir}')
50
                continue
51
52
            # parse config
53
            cfg = mmcv.Config.fromfile(config)
54
            total_iters = cfg.runner.max_iters
55
            exp_metric = cfg.evaluation.metric
56
            if not isinstance(exp_metric, list):
57
                exp_metrics = [exp_metric]
58
59
            # determine whether total_iters ckpt exists
60
            ckpt_path = f'iter_{total_iters}.pth'
61
            if not osp.exists(osp.join(exp_dir, ckpt_path)):
62
                print(f'{config} hasn\'t {ckpt_path}')
63
                continue
64
65
            # only the last log json counts
66
            log_json_path = list(
67
                sorted(glob.glob(osp.join(exp_dir, '*.log.json'))))[-1]
68
69
            # extract metric value
70
            model_performance = get_final_results(log_json_path, total_iters)
71
            if model_performance is None:
72
                print(f'log file error: {log_json_path}')
73
                continue
74
75
            differential_results = dict()
76
            old_results = dict()
77
            new_results = dict()
78
            for metric_key in model_performance:
79
                if metric_key in ['mIoU']:
80
                    metric = round(model_performance[metric_key] * 100, 2)
81
                    old_metric = model_info['metric'][metric_key]
82
                    old_results[metric_key] = old_metric
83
                    new_results[metric_key] = metric
84
                    differential = metric - old_metric
85
                    flag = '+' if differential > 0 else '-'
86
                    differential_results[
87
                        metric_key] = f'{flag}{abs(differential):.2f}'
88
            result_dict[config] = dict(
89
                differential_results=differential_results,
90
                old_results=old_results,
91
                new_results=new_results,
92
            )
93
94
    # 4 save or print results
95
    if metrics_out:
96
        mmcv.dump(result_dict, metrics_out, indent=4)
97
    print('===================================')
98
    for config_name, metrics in result_dict.items():
99
        print(config_name, metrics)
100
    print('===================================')