a b/tools/analyze_logs.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
"""Modified from https://github.com/open-
3
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
4
import argparse
5
import json
6
from collections import defaultdict
7
8
import matplotlib.pyplot as plt
9
import seaborn as sns
10
11
12
def plot_curve(log_dicts, args):
    """Plot the requested metrics of every parsed log on one figure.

    Metrics named ``mIoU``, ``mAcc`` or ``aAcc`` are plotted against the
    epoch keys of the log dict, using the first value recorded for the
    metric in each epoch. Every other metric is plotted against the values
    stored under the ``iter`` key of the same epoch.

    Args:
        log_dicts (list[dict]): One dict per entry of ``args.json_logs``.
            Each maps an epoch to a dict of ``{metric: list of values}``.
        args (argparse.Namespace): Parsed options. This function reads
            ``backend``, ``style``, ``legend``, ``json_logs``, ``keys``,
            ``title`` and ``out``.

    Raises:
        AssertionError: If the number of legend entries is not
            ``len(args.json_logs) * len(args.keys)``.
    """
    # Metrics that get one point per epoch instead of one per iteration.
    eval_metrics = ('mIoU', 'mAcc', 'aAcc')

    if args.backend is not None:
        plt.switch_backend(args.backend)
    sns.set_style(args.style)

    metrics = args.keys
    num_metrics = len(metrics)

    # if legend is None, use {filename}_{key} as legend
    legend = args.legend
    if legend is None:
        legend = [
            f'{json_log}_{metric}' for json_log in args.json_logs
            for metric in metrics
        ]
    assert len(legend) == len(args.json_logs) * num_metrics, (
        'the number of legend entries must equal '
        'len(json_logs) * len(keys)')

    for i, log_dict in enumerate(log_dicts):
        for j, metric in enumerate(metrics):
            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
            is_eval_metric = metric in eval_metrics
            plot_epochs = []
            plot_iters = []
            plot_values = []
            # In some log files the iter numbers are not monotonic;
            # `pre_iter` remembers the last accepted iter so that entries
            # going backwards are dropped instead of drawing wrong lines.
            pre_iter = -1
            for epoch, epoch_logs in log_dict.items():
                if metric not in epoch_logs:
                    continue
                if is_eval_metric:
                    plot_epochs.append(epoch)
                    plot_values.append(epoch_logs[metric][0])
                    continue
                for idx, value in enumerate(epoch_logs[metric]):
                    cur_iter = epoch_logs['iter'][idx]
                    if pre_iter > cur_iter:
                        continue
                    pre_iter = cur_iter
                    plot_iters.append(cur_iter)
                    plot_values.append(value)

            ax = plt.gca()
            # legend is laid out log-major: all metrics of log 0, then log 1
            label = legend[i * num_metrics + j]
            if is_eval_metric:
                ax.set_xticks(plot_epochs)
                plt.xlabel('epoch')
                plt.plot(plot_epochs, plot_values, label=label, marker='o')
            else:
                plt.xlabel('iter')
                plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
        plt.legend()
        if args.title is not None:
            plt.title(args.title)

    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()
69
70
71
def parse_args():
    """Parse the command-line options of the log analysis script.

    Returns:
        argparse.Namespace: Holds ``json_logs`` (one or more paths),
            ``keys`` (metrics to plot, default ``['mIoU']``), ``title``,
            ``legend``, ``backend``, ``style`` (default ``'dark'``) and
            ``out``; options without a stated default are ``None`` when
            not given.
    """
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    parser.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser.add_argument(
        '--keys',
        type=str,
        nargs='+',
        default=['mIoU'],
        help='the metric that you want to plot')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--legend',
        type=str,
        nargs='+',
        default=None,
        help='legend of each plot')
    parser.add_argument(
        '--backend', type=str, default=None, help='backend of plt')
    parser.add_argument(
        '--style', type=str, default='dark', help='style of plt')
    parser.add_argument(
        '--out',
        type=str,
        default=None,
        help='path to save the figure; it is shown instead if not set')
    args = parser.parse_args()
    return args
98
99
100
def load_json_logs(json_logs):
    """Load line-delimited json logs into per-epoch metric dicts.

    Each non-blank line of a log file is parsed as one json object. Objects
    without an ``epoch`` field are skipped. For the others, every remaining
    field is appended to the list stored under that field name for the
    object's epoch.

    Args:
        json_logs (Iterable[str]): Paths of the log files to read.

    Returns:
        list[dict]: One dict per input path, in input order. Keys are
            epochs; values are ``defaultdict(list)`` objects mapping a field
            name to the list of its values in file order.
    """
    log_dicts = []
    for json_log in json_logs:
        log_dict = {}
        with open(json_log, 'r') as log_file:
            for line in log_file:
                line = line.strip()
                # A blank line (e.g. a trailing newline) is not valid json
                # and would make json.loads raise.
                if not line:
                    continue
                log = json.loads(line)
                # skip lines without `epoch` field
                if 'epoch' not in log:
                    continue
                epoch = log.pop('epoch')
                if epoch not in log_dict:
                    log_dict[epoch] = defaultdict(list)
                epoch_logs = log_dict[epoch]
                for k, v in log.items():
                    epoch_logs[k].append(v)
        log_dicts.append(log_dict)
    return log_dicts
118
119
120
def main():
    """Parse the command line, load the json logs and plot the curves.

    Raises:
        AssertionError: If a given log path does not end with ``.json``.
    """
    args = parse_args()
    json_logs = args.json_logs
    for json_log in json_logs:
        assert json_log.endswith('.json'), (
            f'expected a .json log file, got {json_log}')
    log_dicts = load_json_logs(json_logs)
    plot_curve(log_dicts, args)
127
128
129
# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()