|
a |
|
b/plot_learning_curves.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import pandas as pd |
|
|
3 |
import numpy as np |
|
|
4 |
|
|
|
5 |
if __name__ == "__main__": |
|
|
6 |
import argparse |
|
|
7 |
parser = argparse.ArgumentParser(description='plot learnign curve.') |
|
|
8 |
parser.add_argument('history_file', type=str, |
|
|
9 |
help="path to history file.") |
|
|
10 |
parser.add_argument('--plot_style', nargs='*', default=[], |
|
|
11 |
help='plot styles to be used') |
|
|
12 |
parser.add_argument('--save', default='', |
|
|
13 |
help='save the plot in the given file') |
|
|
14 |
args = parser.parse_args() |
|
|
15 |
|
|
|
16 |
if args.plot_style: |
|
|
17 |
plt.style.use(args.plot_style) |
|
|
18 |
|
|
|
19 |
df = pd.read_csv(args.history_file) |
|
|
20 |
|
|
|
21 |
# Plot MAE |
|
|
22 |
fig, ax = plt.subplots() |
|
|
23 |
ax.plot(df['epoch']+1, df['mae'], label='train', color='blue') |
|
|
24 |
ax.set_xlabel('epoch') |
|
|
25 |
ax.set_ylabel('MAE (years)', color='blue') |
|
|
26 |
ax.set_ylim((8, 14)) |
|
|
27 |
axt = ax.twinx() |
|
|
28 |
|
|
|
29 |
# Plot learning rate |
|
|
30 |
axt.step(df['epoch']+1, df['lr'], label='train', alpha=0.4, color='k') |
|
|
31 |
axt.set_yscale('log') |
|
|
32 |
axt.set_ylabel('learning rate', alpha=0.4, color='k') |
|
|
33 |
axt.set_ylim((1e-8, 1e-2)) |
|
|
34 |
|
|
|
35 |
if args.save: |
|
|
36 |
plt.savefig(args.save) |
|
|
37 |
else: |
|
|
38 |
plt.show() |