Diff of /model_train.py [000000] .. [4f54f1]

Switch to side-by-side view

--- a
+++ b/model_train.py
@@ -0,0 +1,286 @@
+import os
+import pandas as pd
+import numpy as np
+
+import tensorflow as tf
+import config
+
+from model_utils import x, y, keep_prob
+from model_utils import input_img, reshape_op
+
+from model_utils import evaluate_log_loss, accuracy, evaluate_validation_set
+from model_utils import model_store_path, store_error_plots, evaluate_test_set
+from model_utils import high_error_increase, display_confusion_matrix_info
+from model_utils import get_specificity, get_sensitivity, validate_data_loaded
+from model import loss_function_with_logits, sparse_loss_with_logits
+from model_factory import ModelFactory
+
+
+# Parameters used during training
+batch_size = config.BATCH_SIZE
+learning_rate = 0.001
+training_iters = 101
+save_step = 10
+display_steps = 20
+validaton_log_loss_incr_threshold = 0.1
+last_errors = 2
+tolerance = 20
+dropout = 0.5 # Dropout, probability to keep units
+beta = 0.01
+
+# Construct model
+factory = ModelFactory()
+model = factory.get_network_model()
+
+if not config.RESTORE:
+    # Add tensors to collection stored in the model graph
+    # definition
+    tf.add_to_collection('vars', x)
+    tf.add_to_collection('vars', y)
+    tf.add_to_collection('vars', keep_prob)
+
+    for weigth_var in model.weights():
+        tf.add_to_collection('vars', weigth_var)
+
+    for bias_var in model.biases():
+        tf.add_to_collection('vars', bias_var)
+
+pred = model.conv_net(x, dropout)
+
+with tf.name_scope("cross_entropy"):
+    # Define loss and optimizer
+    cost = sparse_loss_with_logits(pred, y)
+    
+    # add l2 regularization on the weights on the fully connected layer
+    # if term != 0 is returned
+    regularizer = model.l2_regularizer()
+    if regularizer != 0:
+        print("Adding L2 regularization...")
+        cost = tf.reduce_mean(cost + beta * regularizer)
+
+trainable_vars = tf.trainable_variables()
+
+with tf.name_scope("train"):
+    gradients = tf.gradients(cost, trainable_vars)
+    gradients = list(zip(gradients, trainable_vars))
+    optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
+    train_op = optimizer.apply_gradients(grads_and_vars=gradients)
+
+# Add gradients to summary  
+for gradient, var in gradients:
+    tf.summary.histogram(var.name + '/gradient', gradient)
+
+# Add the variables we train to the summary  
+for var in trainable_vars:
+    tf.summary.histogram(var.name, var)
+
+# Predictions for the training, validation, and test data.
+softmax_prediction = tf.nn.softmax(pred, name='softmax_prediction')
+
+if not config.RESTORE:
+    tf.add_to_collection('vars', cost)
+    tf.add_to_collection('vars', softmax_prediction)
+
+
+merged = tf.summary.merge_all()
+
+# ======= Training ========
+data_loader = factory.get_data_loader()
+training_set = data_loader.get_training_set()
+validation_set = data_loader.get_validation_set()
+exact_tests = data_loader.get_exact_tests_set()
+model_out_dir = data_loader.results_out_dir()
+
+print('Validation examples count: ', validation_set.num_samples)
+print('Test examples count: ', exact_tests.num_samples)
+print('Model will be stored in: ', model_out_dir)
+
+
+# Initializing the variables
+init = tf.global_variables_initializer()
+
+saver = tf.train.Saver()
+
+validation_errors = []
+train_errors_per_epoch = []
+best_validation_err = 1.0
+best_validation_sensitivity = 0.0
+
+# Add summary for log loss per epoch, accuracy and sensitivity
+with tf.name_scope("log_loss"):
+    log_loss = tf.placeholder(tf.float32, name="log_loss_per_epoch")
+    
+loss_summary = tf.summary.scalar("log_loss", log_loss)
+
+with tf.name_scope("sensitivity"):
+    sensitivity = tf.placeholder(tf.float32, name="sensitivity_per_epoch")
+
+sensitivity_summary = tf.summary.scalar("sensitivity", sensitivity)
+
+with tf.name_scope("accuracy"):
+    tf_accuracy = tf.placeholder(tf.float32, name="accuracy_per_epoch")
+
+accuracy_summary = tf.summary.scalar("accuracy", tf_accuracy)
+
+
+def export_evaluation_summary(log_loss_value, 
+                              accuracy_value, 
+                              sensitivity_value, 
+                              step,
+                              sess,
+                              writer):
+    error_summary, acc_summary, sens_summary = sess.run(
+        [loss_summary, accuracy_summary, sensitivity_summary],
+        feed_dict={log_loss: log_loss_value, tf_accuracy: accuracy_value, 
+                   sensitivity: sensitivity_value})
+    writer.add_summary(error_summary, global_step=step)
+    writer.add_summary(acc_summary, global_step=step)
+    writer.add_summary(sens_summary, global_step=step)
+    writer.flush()
+
+
+# Launch the graph
+with tf.Session() as sess:
+    if not os.path.exists(config.SUMMARIES_DIR):
+        os.makedirs(config.SUMMARIES_DIR)
+        
+    train_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR, 'train'))
+    validation_writer = tf.summary.FileWriter(os.path.join(config.SUMMARIES_DIR, 
+                                                           'validation'))
+
+    sess.run(init)
+
+    if config.RESTORE and \
+        os.path.exists(os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT + '.index')):
+        
+        saver.restore(sess, os.path.join(model_out_dir, config.RESTORE_MODEL_CKPT))
+        print("Restoring model from last saved state: ", config.RESTORE_MODEL_CKPT)
+
+
+    # Add the model graph to TensorBoard
+    if not config.RESTORE:
+        train_writer.add_graph(sess.graph)
+
+    for step in range(config.START_STEP, training_iters):
+        train_pred = []
+        train_labels = []
+
+        for i in range(training_set.num_samples):
+            batch_data, batch_labels = training_set.next_batch(batch_size)
+
+            if not validate_data_loaded(batch_data, batch_labels):
+                break
+
+            reshaped = sess.run(reshape_op, feed_dict={input_img: np.stack(batch_data)})
+            feed_dict = {x: reshaped, y: batch_labels, keep_prob: dropout}
+
+            if step % display_steps == 0:
+                _, loss, predictions, summary = sess.run([train_op, cost, softmax_prediction, merged], 
+                                                          feed_dict=feed_dict)
+
+                try:
+                    train_writer.add_summary(summary, step + i)
+                except Exception as e:
+                    print("Exeption raised during summary export. ", e)
+            else:
+                _, loss, predictions = sess.run([train_op, cost, softmax_prediction], 
+                                                 feed_dict=feed_dict)
+
+            train_pred.extend(predictions)
+            train_labels.extend(batch_labels)
+
+        train_writer.flush()
+
+        if step % save_step == 0:
+            print("Storing model snaphost...")
+            saver.save(sess, model_store_path(model_out_dir, 'lungs' + str(step)))
+
+        
+        print("Train epoch {} finished. {} samples processed.".format(
+            training_set.finished_epochs, len(train_pred)))
+
+        if not len(train_pred):
+            break
+            
+        train_acc_epoch = accuracy(np.stack(train_pred), np.stack(train_labels))
+
+        train_log_loss = evaluate_log_loss(train_pred, train_labels)
+    
+        print('Train log loss error {}.'.format(train_log_loss))
+        print('Train set accuracy {}.'.format(train_acc_epoch))
+        print('Train set confusion matrix.')
+        confusion_matrix = display_confusion_matrix_info(train_labels, train_pred)
+        train_sensitivity = get_sensitivity(confusion_matrix)
+        train_specificity = get_specificity(confusion_matrix)
+        print('Test data sensitivity {} and specificity {}'.format(
+            train_sensitivity, train_specificity))
+
+        export_evaluation_summary(train_log_loss, 
+                                  train_acc_epoch, 
+                                  train_sensitivity, 
+                                  step, sess, train_writer)
+
+        print('Evaluate validation set')
+        validation_acc, validation_log_loss, val_sensitivity, val_specificity = evaluate_validation_set(sess, 
+                                                                                                        validation_set,
+                                                                                                        softmax_prediction,
+                                                                                                        x,
+                                                                                                        batch_size)
+        if not validation_log_loss:
+            break
+        export_evaluation_summary(validation_log_loss, 
+                                  validation_acc, 
+                                  val_sensitivity, 
+                                  step, sess, validation_writer)
+
+        print('Validation accuracy: %.1f%%' % validation_acc)
+        print('Log loss overall validation samples: {}.'.format(
+            validation_log_loss))
+        print('Validation set sensitivity {} and specificity {}'.format(
+            val_sensitivity, val_specificity))
+
+        if validation_log_loss < best_validation_err and val_sensitivity > best_validation_sensitivity:
+            best_validation_err = validation_log_loss
+            best_validation_sensitivity = val_sensitivity
+            print("Storing model snaphost with best validation error {} and sensitivity {} ".format(
+                best_validation_err, best_validation_sensitivity))
+            if step % save_step != 0:
+                saver.save(sess, model_store_path(model_out_dir, 'best_err' + str(step)))
+
+        if validation_log_loss < 0.1:
+            print("Low enough log loss validation error, terminate!")
+            break;
+
+        if high_error_increase(validation_errors[-last_errors:], 
+                               validation_log_loss,
+                               last_errors,
+                               validaton_log_loss_incr_threshold):
+            if tolerance and train_log_loss <= train_errors_per_epoch[-1]:
+                print("Train error still decreases, continue...")
+                tolerance -= 1
+                validation_errors.append(validation_log_loss)
+                train_errors_per_epoch.append(train_log_loss)
+                continue
+
+            print("Validation log loss has increased more than the allowed threshold",
+                  " for the past iterations, terminate!")
+            print("Last iterations: ", validation_errors[-last_errors:])
+            print("Current validation error: ", validation_log_loss)
+            break
+
+        validation_errors.append(validation_log_loss)
+        train_errors_per_epoch.append(train_log_loss)
+
+    train_writer.close()
+    validation_writer.close()
+
+    saver.save(sess, model_store_path(model_out_dir, 'last'))
+    print("Model saved...")
+    store_error_plots(validation_errors, train_errors_per_epoch)
+
+
+    # ============= REAL TEST DATA EVALUATION =====================
+    evaluate_test_set(sess,
+                      exact_tests,
+                      softmax_prediction,
+                      x)