Switch to unified view

a b/scripts/plot_objectives.py
1
import matplotlib
2
matplotlib.use('Agg')
3
import matplotlib.pyplot as plt
4
import os
5
import numpy as np
6
import cPickle as pickle
7
import time
8
import sys
9
10
filename = sys.argv[1]
11
file = open(filename)
12
13
last_chunk = -1
14
training_errors = []
15
validation_errors = []
16
training_idcs = []
17
validation_idcs=[]
18
19
for line in file:
20
    if 'Chunk' in line :
21
        last_chunk = int(line.split()[1].split('/')[0])
22
    if 'Validation loss' in line:
23
        validation_errors.append(float(line.split(':')[1].rsplit()[0]))
24
        validation_idcs.append(last_chunk)
25
    if 'Mean train loss' in line:
26
        training_errors.append(float(line.split(':')[1].rsplit()[0]))
27
        training_idcs.append(last_chunk)
28
29
30
print 'training errors'
31
print training_errors
32
print training_idcs
33
print 'validation errors'
34
print validation_errors
35
print validation_idcs
36
37
print 'min training error', np.amin(np.array(training_errors)), 'at', np.argmin(np.array(training_errors))
38
print 'min validation error', np.amin(np.array(validation_errors)), 'at', np.argmin(np.array(validation_errors))
39
40
plt.plot(training_errors, label='training errors')
41
plt.plot(validation_errors, label='validation errors')
42
plt.legend(loc="upper right")
43
plt.title(sys.argv[1])
44
plt.xlabel('Epoch')
45
#plt.ylim(0, 0.7)
46
plt.ylabel('Error') 
47
plt.savefig(sys.argv[2])
48
49