Diff of /plot_learning_curves.py [000000] .. [a8f942]

Switch to unified view

a b/plot_learning_curves.py
1
import matplotlib.pyplot as plt
2
import pandas as pd
3
import numpy as np
4
5
if __name__ == "__main__":
6
    import argparse
7
    parser = argparse.ArgumentParser(description='plot learnign curve.')
8
    parser.add_argument('history_file', type=str,
9
                        help="path to history file.")
10
    parser.add_argument('--plot_style', nargs='*', default=[],
11
                        help='plot styles to be used')
12
    parser.add_argument('--save', default='',
13
                        help='save the plot in the given file')
14
    args = parser.parse_args()
15
16
    if args.plot_style:
17
        plt.style.use(args.plot_style)
18
19
    df = pd.read_csv(args.history_file)
20
21
    # Plot MAE
22
    fig, ax = plt.subplots()
23
    ax.plot(df['epoch']+1, df['mae'], label='train', color='blue')
24
    ax.set_xlabel('epoch')
25
    ax.set_ylabel('MAE (years)', color='blue')
26
    ax.set_ylim((8, 14))
27
    axt = ax.twinx()
28
29
    # Plot learning rate
30
    axt.step(df['epoch']+1, df['lr'], label='train', alpha=0.4, color='k')
31
    axt.set_yscale('log')
32
    axt.set_ylabel('learning rate', alpha=0.4, color='k')
33
    axt.set_ylim((1e-8, 1e-2))
34
35
    if args.save:
36
        plt.savefig(args.save)
37
    else:
38
        plt.show()