Diff of /plot_trainingloss.py [000000] .. [168bda]

Switch to unified view

a b/plot_trainingloss.py
1
import os
2
import sys
3
import numpy as np
4
import matplotlib.pyplot as plt
5
import math
6
import pylab
7
import sys
8
import argparse
9
import re
10
from pylab import figure, show, legend, ylabel
11
12
from mpl_toolkits.axes_grid1 import host_subplot
13
14
if __name__ == "__main__":
15
    plt.ion()
16
    host = host_subplot(111)
17
    host.set_xlabel("Iterations")
18
    host.set_ylabel("Loss")
19
    plt.subplots_adjust(right=0.75)
20
21
22
    while True:
23
        parser = argparse.ArgumentParser(description='makes a plot from Caffe output')
24
        parser.add_argument('output_file', help='file of captured stdout and stderr')
25
        args = parser.parse_args()
26
27
        f = open(args.output_file, 'r')
28
29
        training_iterations = []
30
        training_loss = []
31
32
        test_iterations = []
33
        test_accuracy = []
34
        test_loss = []
35
36
        check_test = False
37
        check_test2 = False
38
        for line in f:
39
40
            # if check_test:
41
            #     #test_accuracy.append(float(line.strip().split(' = ')[-1]))
42
            #     check_test = False
43
            #     check_test2 = True
44
            # elif check_test2:
45
            if 'Test net output' in line and 'loss = ' in line:
46
                # print line
47
                #print line.strip().split(' ')
48
                test_loss.append(float(line.strip().split(' ')[-2]))
49
                check_test2 = False
50
            # else:
51
            #     test_loss.append(0)
52
            #     check_test2 = False
53
54
            if '] Iteration ' in line and 'loss = ' in line:
55
                arr = re.findall(r'ion \b\d+\b,', line)
56
                training_iterations.append(int(arr[0].strip(',')[4:]))
57
                training_loss.append(float(line.strip().split(' = ')[-1]))
58
59
            if '] Iteration ' in line and 'Testing net' in line:
60
                arr = re.findall(r'ion \b\d+\b,', line)
61
                test_iterations.append(int(arr[0].strip(',')[4:]))
62
                check_test = True
63
64
        print 'train iterations len: ', len(training_iterations)
65
        print 'train loss len: ', len(training_loss)
66
        print 'test loss len: ', len(test_loss)
67
        print 'test iterations len: ', len(test_iterations)
68
        #print 'test accuracy len: ', len(test_accuracy)
69
70
        # if len(test_iterations) != len(test_accuracy):  # awaiting test...
71
        #     print 'mis-match'
72
        #     print len(test_iterations[0:-1])
73
        #     test_iterations = test_iterations[0:-1]
74
75
        f.close()
76
        #  plt.plot(training_iterations, training_loss, '-', linewidth=2)
77
        #  plt.plot(test_iterations, test_accuracy, '-', linewidth=2)
78
        #  plt.show()
79
80
        # host = host_subplot(111)  # , axes_class=AA.Axes)
81
        # plt.subplots_adjust(right=0.75)
82
83
        #par1 = host.twinx()
84
85
        # host.set_xlabel("iterations")
86
        # host.set_ylabel("log loss")
87
        #par1.set_ylabel("validation accuracy")
88
89
        host.clear()
90
        host.clear()
91
        host.set_xlabel("Iterations")
92
        host.set_ylabel("Loss")
93
        #p1, = host.plot(training_iterations, training_loss, label="training loss")
94
        if len(training_iterations) == len(training_loss):
95
            p1, = host.plot(training_iterations, training_loss, label="training loss")
96
        if len(test_iterations) == len(test_loss):
97
            p3, = host.plot(test_iterations, test_loss, label="valdation loss")
98
        #p2, = par1.plot(test_iterations, test_accuracy, label="validation accuracy")
99
100
        host.legend(loc=2)
101
102
        #host.axis["left"].label.set_color(p1.get_color())
103
        #par1.axis["right"].label.set_color(p2.get_color())
104
        #fig = plt.figure()
105
        #fig.patch.set_facecolor('white')
106
107
        #axes = plt.gca()
108
        #ymin, ymax = min(training_loss), max(training_loss)
109
        #axes.set_xlim([xmin, xmax])
110
        #axes.set_ylim([0, ymax])
111
        #plt.yticks([0, 0.2, 0.4, 0.6, 0.8,1.0, 1.2, 1.4, 1.6])
112
        plt.grid()
113
        plt.draw()
114
        plt.show()
115
        plt.pause(5)
116
117
118
119