|
a |
|
b/plot_trainingloss.py |
|
|
1 |
import os |
|
|
2 |
import sys |
|
|
3 |
import numpy as np |
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
import math |
|
|
6 |
import pylab |
|
|
7 |
import sys |
|
|
8 |
import argparse |
|
|
9 |
import re |
|
|
10 |
from pylab import figure, show, legend, ylabel |
|
|
11 |
|
|
|
12 |
from mpl_toolkits.axes_grid1 import host_subplot |
|
|
13 |
|
|
|
14 |
if __name__ == "__main__": |
|
|
15 |
plt.ion() |
|
|
16 |
host = host_subplot(111) |
|
|
17 |
host.set_xlabel("Iterations") |
|
|
18 |
host.set_ylabel("Loss") |
|
|
19 |
plt.subplots_adjust(right=0.75) |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
while True: |
|
|
23 |
parser = argparse.ArgumentParser(description='makes a plot from Caffe output') |
|
|
24 |
parser.add_argument('output_file', help='file of captured stdout and stderr') |
|
|
25 |
args = parser.parse_args() |
|
|
26 |
|
|
|
27 |
f = open(args.output_file, 'r') |
|
|
28 |
|
|
|
29 |
training_iterations = [] |
|
|
30 |
training_loss = [] |
|
|
31 |
|
|
|
32 |
test_iterations = [] |
|
|
33 |
test_accuracy = [] |
|
|
34 |
test_loss = [] |
|
|
35 |
|
|
|
36 |
check_test = False |
|
|
37 |
check_test2 = False |
|
|
38 |
for line in f: |
|
|
39 |
|
|
|
40 |
# if check_test: |
|
|
41 |
# #test_accuracy.append(float(line.strip().split(' = ')[-1])) |
|
|
42 |
# check_test = False |
|
|
43 |
# check_test2 = True |
|
|
44 |
# elif check_test2: |
|
|
45 |
if 'Test net output' in line and 'loss = ' in line: |
|
|
46 |
# print line |
|
|
47 |
#print line.strip().split(' ') |
|
|
48 |
test_loss.append(float(line.strip().split(' ')[-2])) |
|
|
49 |
check_test2 = False |
|
|
50 |
# else: |
|
|
51 |
# test_loss.append(0) |
|
|
52 |
# check_test2 = False |
|
|
53 |
|
|
|
54 |
if '] Iteration ' in line and 'loss = ' in line: |
|
|
55 |
arr = re.findall(r'ion \b\d+\b,', line) |
|
|
56 |
training_iterations.append(int(arr[0].strip(',')[4:])) |
|
|
57 |
training_loss.append(float(line.strip().split(' = ')[-1])) |
|
|
58 |
|
|
|
59 |
if '] Iteration ' in line and 'Testing net' in line: |
|
|
60 |
arr = re.findall(r'ion \b\d+\b,', line) |
|
|
61 |
test_iterations.append(int(arr[0].strip(',')[4:])) |
|
|
62 |
check_test = True |
|
|
63 |
|
|
|
64 |
print 'train iterations len: ', len(training_iterations) |
|
|
65 |
print 'train loss len: ', len(training_loss) |
|
|
66 |
print 'test loss len: ', len(test_loss) |
|
|
67 |
print 'test iterations len: ', len(test_iterations) |
|
|
68 |
#print 'test accuracy len: ', len(test_accuracy) |
|
|
69 |
|
|
|
70 |
# if len(test_iterations) != len(test_accuracy): # awaiting test... |
|
|
71 |
# print 'mis-match' |
|
|
72 |
# print len(test_iterations[0:-1]) |
|
|
73 |
# test_iterations = test_iterations[0:-1] |
|
|
74 |
|
|
|
75 |
f.close() |
|
|
76 |
# plt.plot(training_iterations, training_loss, '-', linewidth=2) |
|
|
77 |
# plt.plot(test_iterations, test_accuracy, '-', linewidth=2) |
|
|
78 |
# plt.show() |
|
|
79 |
|
|
|
80 |
# host = host_subplot(111) # , axes_class=AA.Axes) |
|
|
81 |
# plt.subplots_adjust(right=0.75) |
|
|
82 |
|
|
|
83 |
#par1 = host.twinx() |
|
|
84 |
|
|
|
85 |
# host.set_xlabel("iterations") |
|
|
86 |
# host.set_ylabel("log loss") |
|
|
87 |
#par1.set_ylabel("validation accuracy") |
|
|
88 |
|
|
|
89 |
host.clear() |
|
|
90 |
host.clear() |
|
|
91 |
host.set_xlabel("Iterations") |
|
|
92 |
host.set_ylabel("Loss") |
|
|
93 |
#p1, = host.plot(training_iterations, training_loss, label="training loss") |
|
|
94 |
if len(training_iterations) == len(training_loss): |
|
|
95 |
p1, = host.plot(training_iterations, training_loss, label="training loss") |
|
|
96 |
if len(test_iterations) == len(test_loss): |
|
|
97 |
p3, = host.plot(test_iterations, test_loss, label="valdation loss") |
|
|
98 |
#p2, = par1.plot(test_iterations, test_accuracy, label="validation accuracy") |
|
|
99 |
|
|
|
100 |
host.legend(loc=2) |
|
|
101 |
|
|
|
102 |
#host.axis["left"].label.set_color(p1.get_color()) |
|
|
103 |
#par1.axis["right"].label.set_color(p2.get_color()) |
|
|
104 |
#fig = plt.figure() |
|
|
105 |
#fig.patch.set_facecolor('white') |
|
|
106 |
|
|
|
107 |
#axes = plt.gca() |
|
|
108 |
#ymin, ymax = min(training_loss), max(training_loss) |
|
|
109 |
#axes.set_xlim([xmin, xmax]) |
|
|
110 |
#axes.set_ylim([0, ymax]) |
|
|
111 |
#plt.yticks([0, 0.2, 0.4, 0.6, 0.8,1.0, 1.2, 1.4, 1.6]) |
|
|
112 |
plt.grid() |
|
|
113 |
plt.draw() |
|
|
114 |
plt.show() |
|
|
115 |
plt.pause(5) |
|
|
116 |
|
|
|
117 |
|
|
|
118 |
|
|
|
119 |
|