|
a |
|
b/tools/analyze_logs.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
"""Modified from https://github.com/open-mmlab/mmdetection/blob/master/tools/
analysis_tools/analyze_logs.py."""
|
|
4 |
import argparse |
|
|
5 |
import json |
|
|
6 |
from collections import defaultdict |
|
|
7 |
|
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
import seaborn as sns |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def _aligned_iters(epoch_logs, metric):
    """Return iteration numbers that pair index-wise with ``epoch_logs[metric]``.

    ``load_json_logs`` appends every key of a log line to its own list, so
    when some lines carry ``iter`` but not ``metric`` (for example, evaluation
    lines, which presumably do not log training losses) the ``iter`` list is
    longer than the metric list and the two are no longer aligned by index.
    When a ``mode`` list parallel to ``iter`` exists and its ``'train'``
    entries account for exactly the metric's values, those iterations are
    returned instead. Otherwise the raw ``iter`` list is returned unchanged.

    Args:
        epoch_logs (dict): Key -> list of values for one epoch.
        metric (str): Key whose values are being plotted.

    Returns:
        list: Iteration numbers to pair with ``epoch_logs[metric]``.
    """
    iters = epoch_logs['iter']
    num_values = len(epoch_logs[metric])
    if len(iters) == num_values:
        return iters
    modes = epoch_logs.get('mode')
    if modes is not None and len(modes) == len(iters):
        train_iters = [
            it for it, mode in zip(iters, modes) if mode == 'train'
        ]
        if len(train_iters) == num_values:
            return train_iters
    return iters


def plot_curve(log_dicts, args):
    """Plot the requested keys of every log file on a single figure.

    Evaluation metrics (``mIoU``, ``mAcc``, ``aAcc``) are plotted against the
    epoch, using the first value recorded in each epoch. Every other key is
    plotted against the iteration number.

    Args:
        log_dicts (list[dict]): One dict per log file, as built by
            ``load_json_logs``: epoch -> {key: list of values}.
        args (argparse.Namespace): Parsed CLI options; reads ``json_logs``,
            ``keys``, ``legend``, ``backend``, ``style``, ``title`` and
            ``out``.

    The figure is shown interactively when ``args.out`` is None, otherwise
    it is saved to ``args.out``.
    """
    if args.backend is not None:
        plt.switch_backend(args.backend)
    sns.set_style(args.style)

    eval_metrics = ('mIoU', 'mAcc', 'aAcc')
    metrics = args.keys
    num_metrics = len(metrics)

    # if legend is None, use {filename}_{key} as legend
    legend = args.legend
    if legend is None:
        legend = [
            f'{json_log}_{metric}' for json_log in args.json_logs
            for metric in metrics
        ]
    expected = len(args.json_logs) * num_metrics
    assert len(legend) == expected, (
        f'expected {expected} legend entries (one per log file and key), '
        f'got {len(legend)}')

    for i, log_dict in enumerate(log_dicts):
        for j, metric in enumerate(metrics):
            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
            plot_epochs = []
            plot_iters = []
            plot_values = []
            # In some log files, iters number is not correct, `pre_iter` is
            # used to prevent generate wrong lines.
            pre_iter = -1
            for epoch, epoch_logs in log_dict.items():
                if metric not in epoch_logs:
                    continue
                values = epoch_logs[metric]
                if metric in eval_metrics:
                    plot_epochs.append(epoch)
                    plot_values.append(values[0])
                    continue
                iters = _aligned_iters(epoch_logs, metric)
                for idx, value in enumerate(values):
                    cur_iter = iters[idx]
                    # Drop points whose iteration goes backwards.
                    if pre_iter > cur_iter:
                        continue
                    pre_iter = cur_iter
                    plot_iters.append(cur_iter)
                    plot_values.append(value)

            label = legend[i * num_metrics + j]
            if metric in eval_metrics:
                plt.gca().set_xticks(plot_epochs)
                plt.xlabel('epoch')
                plt.plot(plot_epochs, plot_values, label=label, marker='o')
            else:
                plt.xlabel('iter')
                plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)

    plt.legend()
    if args.title is not None:
        plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()
|
|
69 |
|
|
|
70 |
|
|
|
71 |
def parse_args(argv=None):
    """Parse the command-line options of this tool.

    Args:
        argv (list[str], optional): Arguments to parse instead of the
            process command line. Defaults to None, which makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Options ``json_logs``, ``keys``, ``title``,
        ``legend``, ``backend``, ``style`` and ``out``.
    """
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    parser.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser.add_argument(
        '--keys',
        type=str,
        nargs='+',
        default=['mIoU'],
        help='the metric that you want to plot')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--legend',
        type=str,
        nargs='+',
        default=None,
        help='legend of each plot; one entry per (log file, key) pair')
    parser.add_argument(
        '--backend', type=str, default=None, help='backend of plt')
    parser.add_argument(
        '--style', type=str, default='dark', help='style of plt')
    parser.add_argument(
        '--out',
        type=str,
        default=None,
        help='path to save the figure; if omitted, the figure is shown')
    return parser.parse_args(argv)
|
|
98 |
|
|
|
99 |
|
|
|
100 |
def load_json_logs(json_logs):
    """Load JSON-lines log files (one JSON object per line), grouped by epoch.

    Args:
        json_logs (list[str]): Paths of the log files.

    Returns:
        list[dict]: One dict per file, in input order. Each maps an epoch to
        a ``defaultdict(list)`` whose keys are the other fields of that
        epoch's log lines and whose values list those fields in file order.
        Blank lines and lines without an ``epoch`` field are skipped.
    """
    log_dicts = []
    for json_log in json_logs:
        log_dict = {}
        with open(json_log, 'r', encoding='utf-8') as log_file:
            for line in log_file:
                line = line.strip()
                # json.loads('') raises, so tolerate blank/trailing lines.
                if not line:
                    continue
                log = json.loads(line)
                # skip lines without `epoch` field
                if 'epoch' not in log:
                    continue
                epoch = log.pop('epoch')
                if epoch not in log_dict:
                    log_dict[epoch] = defaultdict(list)
                epoch_log = log_dict[epoch]
                for k, v in log.items():
                    epoch_log[k].append(v)
        log_dicts.append(log_dict)
    return log_dicts
|
|
118 |
|
|
|
119 |
|
|
|
120 |
def main():
    """Command-line entry point: load the given JSON logs and plot them.

    Raises:
        ValueError: If any positional log path does not end with ``.json``.
    """
    args = parse_args()
    json_logs = args.json_logs
    # Raise explicitly rather than `assert`, which `python -O` strips.
    not_json = [path for path in json_logs if not path.endswith('.json')]
    if not_json:
        raise ValueError(
            f'log files must have a .json extension, got: {not_json}')
    log_dicts = load_json_logs(json_logs)
    plot_curve(log_dicts, args)


if __name__ == '__main__':
    main()