Diff of /trained_model_loader.py [000000] .. [4f54f1]

Switch to unified view

a b/trained_model_loader.py
1
import os
2
import tensorflow as tf
3
4
from utils import store_to_csv
5
import config
6
7
from model_utils import x, evaluate_test_set, evaluate_solution
8
from model_factory import ModelFactory
9
10
11
# Construct model
12
factory = ModelFactory()
13
model = factory.get_network_model()
14
15
softmax_prediction = tf.nn.softmax(model.conv_net(x, 1.0), 
16
    name='softmax_prediction')
17
18
data_loader = factory.get_data_loader()
19
test_set = data_loader.get_exact_tests_set()
20
out_dir = data_loader.results_out_dir()
21
print(out_dir)
22
23
24
saver = tf.train.Saver()
25
26
with tf.Session() as sess:
27
    sess.run(tf.global_variables_initializer())
28
    if os.path.exists(os.path.join(out_dir, config.RESTORE_MODEL_CKPT + '.index')):
29
        saver.restore(sess, os.path.join(out_dir, config.RESTORE_MODEL_CKPT))
30
31
        evaluate_test_set(sess, 
32
                          test_set,
33
                          softmax_prediction,
34
                          x)
35
        if os.path.exists(config.SOLUTION_FILE_PATH):
36
            print("Evaluate generated solution...")
37
            evaluate_solution(config.SOLUTION_FILE_PATH)
38
        else:
39
          print("Solution file was not generated, check if test data is complete...")
40
    else:
41
        print("Checkpoint file {} does not exist in the configured directory {}.".format(
42
            config.RESTORE_MODEL_CKPT, out_dir))
43