"""Training script for the network produced by ``ModelFactory``.

The script builds the TensorFlow graph (loss, optional L2 term, Adagrad
update, summaries), then runs up to ``training_iters`` epochs. After each
epoch it reports training metrics, evaluates the validation set, stores
checkpoints and decides whether to stop early. When training ends it stores
a final checkpoint, the error plots, and evaluates the test set.
"""
import os
import pandas as pd
import numpy as np

import tensorflow as tf
import config

from model_utils import x, y, keep_prob
from model_utils import input_img, reshape_op

from model_utils import evaluate_log_loss, accuracy, evaluate_validation_set
from model_utils import model_store_path, store_error_plots, evaluate_test_set
from model_utils import high_error_increase, display_confusion_matrix_info
from model_utils import get_specificity, get_sensitivity, validate_data_loaded
from model import loss_function_with_logits, sparse_loss_with_logits
from model_factory import ModelFactory


# Parameters used during training
batch_size = config.BATCH_SIZE
learning_rate = 0.001
training_iters = 101
# Store a checkpoint every `save_step` epochs.
save_step = 10
# Export the merged (histogram) summaries every `display_steps` epochs.
display_steps = 20
validaton_log_loss_incr_threshold = 0.1
# Number of most recent validation errors handed to `high_error_increase`.
last_errors = 2
# How many times a validation-error increase is forgiven while the training
# error is still not increasing.
tolerance = 20
dropout = 0.5  # Dropout, probability to keep units
# Weight of the L2 regularization term in the cost.
beta = 0.01

# Construct model
factory = ModelFactory()
model = factory.get_network_model()

if not config.RESTORE:
    # Add tensors to collection stored in the model graph
    # definition
    tf.add_to_collection('vars', x)
    tf.add_to_collection('vars', y)
    tf.add_to_collection('vars', keep_prob)

    for weigth_var in model.weights():
        tf.add_to_collection('vars', weigth_var)

    for bias_var in model.biases():
        tf.add_to_collection('vars', bias_var)

# NOTE(review): the network is built with the Python constant `dropout`,
# not with the `keep_prob` placeholder that is fed during training below.
# Confirm whether validation/test evaluation is meant to run with dropout
# disabled; if so, `conv_net` presumably needs `keep_prob` here.
pred = model.conv_net(x, dropout)

with tf.name_scope("cross_entropy"):
    # Define loss and optimizer
    cost = sparse_loss_with_logits(pred, y)

    # add l2 regularization on the weights on the fully connected layer
    # if term != 0 is returned
    regularizer = model.l2_regularizer()
    if regularizer != 0:
        print("Adding L2 regularization...")
        cost = tf.reduce_mean(cost + beta * regularizer)

trainable_vars = tf.trainable_variables()

with tf.name_scope("train"):
    # Gradients are computed explicitly (instead of `minimize`) so that they
    # can be exported as histograms below.
    gradients = tf.gradients(cost, trainable_vars)
    gradients = list(zip(gradients, trainable_vars))
    optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in trainable_vars:
    tf.summary.histogram(var.name, var)

# Predictions for the training, validation, and test data.
softmax_prediction = tf.nn.softmax(pred, name='softmax_prediction')

if not config.RESTORE:
    tf.add_to_collection('vars', cost)
    tf.add_to_collection('vars', softmax_prediction)


# Merge only the histogram summaries defined so far; the scalar summaries
# created further down are run separately in `export_evaluation_summary`.
merged = tf.summary.merge_all()

# ======= Training ========
data_loader = factory.get_data_loader()
training_set = data_loader.get_training_set()
validation_set = data_loader.get_validation_set()
exact_tests = data_loader.get_exact_tests_set()
model_out_dir = data_loader.results_out_dir()

print('Validation examples count: ', validation_set.num_samples)
print('Test examples count: ', exact_tests.num_samples)
print('Model will be stored in: ', model_out_dir)


# Initializing the variables
init = tf.global_variables_initializer()

saver = tf.train.Saver()

validation_errors = []
train_errors_per_epoch = []
best_validation_err = 1.0
best_validation_sensitivity = 0.0

# Add summary for log loss per epoch, accuracy and sensitivity
with tf.name_scope("log_loss"):
    log_loss = tf.placeholder(tf.float32, name="log_loss_per_epoch")

loss_summary = tf.summary.scalar("log_loss", log_loss)

with tf.name_scope("sensitivity"):
    sensitivity = tf.placeholder(tf.float32, name="sensitivity_per_epoch")

sensitivity_summary = tf.summary.scalar("sensitivity", sensitivity)

with tf.name_scope("accuracy"):
    tf_accuracy = tf.placeholder(tf.float32, name="accuracy_per_epoch")

accuracy_summary = tf.summary.scalar("accuracy", tf_accuracy)


def export_evaluation_summary(log_loss_value,
                              accuracy_value,
                              sensitivity_value,
                              step,
                              sess,
                              writer):
    """Write the per-epoch scalar summaries to ``writer`` and flush it.

    The three values are fed into the module-level ``log_loss``,
    ``tf_accuracy`` and ``sensitivity`` placeholders, the matching scalar
    summary ops are evaluated, and each result is added to ``writer``.

    Args:
        log_loss_value: Value fed into the ``log_loss`` placeholder.
        accuracy_value: Value fed into the ``tf_accuracy`` placeholder.
        sensitivity_value: Value fed into the ``sensitivity`` placeholder.
        step: Global step under which the summaries are recorded.
        sess: Session used to evaluate the summary ops.
        writer: Summary writer that receives the three summaries.
    """
    error_summary, acc_summary, sens_summary = sess.run(
        [loss_summary, accuracy_summary, sensitivity_summary],
        feed_dict={log_loss: log_loss_value, tf_accuracy: accuracy_value,
                   sensitivity: sensitivity_value})
    writer.add_summary(error_summary, global_step=step)
    writer.add_summary(acc_summary, global_step=step)
    writer.add_summary(sens_summary, global_step=step)
    writer.flush()


# Launch the graph
with tf.Session() as sess:
    if not os.path.exists(config.SUMMARIES_DIR):
        os.makedirs(config.SUMMARIES_DIR)

    train_writer = tf.summary.FileWriter(
        os.path.join(config.SUMMARIES_DIR, 'train'))
    validation_writer = tf.summary.FileWriter(
        os.path.join(config.SUMMARIES_DIR, 'validation'))

    # The writers are closed in `finally` so they are released even when
    # training aborts with an exception.
    try:
        sess.run(init)

        if config.RESTORE and \
                os.path.exists(os.path.join(model_out_dir,
                                            config.RESTORE_MODEL_CKPT + '.index')):

            saver.restore(sess, os.path.join(model_out_dir,
                                             config.RESTORE_MODEL_CKPT))
            print("Restoring model from last saved state: ",
                  config.RESTORE_MODEL_CKPT)

        # Add the model graph to TensorBoard
        if not config.RESTORE:
            train_writer.add_graph(sess.graph)

        for step in range(config.START_STEP, training_iters):
            train_pred = []
            train_labels = []

            # The loop bound is the sample count, not the batch count; the
            # epoch ends when `validate_data_loaded` rejects a batch.
            for i in range(training_set.num_samples):
                batch_data, batch_labels = training_set.next_batch(batch_size)

                if not validate_data_loaded(batch_data, batch_labels):
                    break

                reshaped = sess.run(reshape_op,
                                    feed_dict={input_img: np.stack(batch_data)})
                feed_dict = {x: reshaped, y: batch_labels, keep_prob: dropout}

                if step % display_steps == 0:
                    _, _, predictions, summary = sess.run(
                        [train_op, cost, softmax_prediction, merged],
                        feed_dict=feed_dict)

                    # Summary export is best effort: a failure is reported
                    # but must not interrupt training.
                    # NOTE(review): `step + i` can repeat across epochs when
                    # an epoch has more than `display_steps` batches; confirm
                    # whether a monotonically increasing step is wanted.
                    try:
                        train_writer.add_summary(summary, step + i)
                    except Exception as e:
                        print("Exeption raised during summary export. ", e)
                else:
                    _, _, predictions = sess.run(
                        [train_op, cost, softmax_prediction],
                        feed_dict=feed_dict)

                train_pred.extend(predictions)
                train_labels.extend(batch_labels)

            train_writer.flush()

            if step % save_step == 0:
                print("Storing model snaphost...")
                saver.save(sess, model_store_path(model_out_dir,
                                                  'lungs' + str(step)))

            print("Train epoch {} finished. {} samples processed.".format(
                training_set.finished_epochs, len(train_pred)))

            # Nothing was processed in this epoch: stop training.
            if not len(train_pred):
                break

            train_acc_epoch = accuracy(np.stack(train_pred),
                                       np.stack(train_labels))

            train_log_loss = evaluate_log_loss(train_pred, train_labels)

            print('Train log loss error {}.'.format(train_log_loss))
            print('Train set accuracy {}.'.format(train_acc_epoch))
            print('Train set confusion matrix.')
            confusion_matrix = display_confusion_matrix_info(train_labels,
                                                             train_pred)
            train_sensitivity = get_sensitivity(confusion_matrix)
            train_specificity = get_specificity(confusion_matrix)
            # These metrics come from the training predictions, so they are
            # reported as train-set values.
            print('Train set sensitivity {} and specificity {}'.format(
                train_sensitivity, train_specificity))

            export_evaluation_summary(train_log_loss,
                                      train_acc_epoch,
                                      train_sensitivity,
                                      step, sess, train_writer)

            print('Evaluate validation set')
            (validation_acc,
             validation_log_loss,
             val_sensitivity,
             val_specificity) = evaluate_validation_set(sess,
                                                        validation_set,
                                                        softmax_prediction,
                                                        x,
                                                        batch_size)
            # NOTE(review): a falsy log loss ends training; this also
            # triggers for a loss of exactly 0.0 -- confirm what
            # `evaluate_validation_set` returns on failure.
            if not validation_log_loss:
                break
            export_evaluation_summary(validation_log_loss,
                                      validation_acc,
                                      val_sensitivity,
                                      step, sess, validation_writer)

            print('Validation accuracy: %.1f%%' % validation_acc)
            print('Log loss overall validation samples: {}.'.format(
                validation_log_loss))
            print('Validation set sensitivity {} and specificity {}'.format(
                val_sensitivity, val_specificity))

            # A new "best" model must improve both the log loss and the
            # sensitivity seen so far.
            if validation_log_loss < best_validation_err and \
                    val_sensitivity > best_validation_sensitivity:
                best_validation_err = validation_log_loss
                best_validation_sensitivity = val_sensitivity
                print("Storing model snaphost with best validation error {} and sensitivity {} ".format(
                    best_validation_err, best_validation_sensitivity))
                # On save-step epochs a snapshot was already written above.
                if step % save_step != 0:
                    saver.save(sess, model_store_path(model_out_dir,
                                                      'best_err' + str(step)))

            if validation_log_loss < 0.1:
                print("Low enough log loss validation error, terminate!")
                break

            if high_error_increase(validation_errors[-last_errors:],
                                   validation_log_loss,
                                   last_errors,
                                   validaton_log_loss_incr_threshold):
                # Forgive the increase while tolerance remains and the
                # training error has not grown since the previous epoch.
                if tolerance and train_log_loss <= train_errors_per_epoch[-1]:
                    print("Train error still decreases, continue...")
                    tolerance -= 1
                    validation_errors.append(validation_log_loss)
                    train_errors_per_epoch.append(train_log_loss)
                    continue

                print("Validation log loss has increased more than the allowed threshold",
                      " for the past iterations, terminate!")
                print("Last iterations: ", validation_errors[-last_errors:])
                print("Current validation error: ", validation_log_loss)
                break

            validation_errors.append(validation_log_loss)
            train_errors_per_epoch.append(train_log_loss)
    finally:
        train_writer.close()
        validation_writer.close()

    saver.save(sess, model_store_path(model_out_dir, 'last'))
    print("Model saved...")
    store_error_plots(validation_errors, train_errors_per_epoch)

    # ============= REAL TEST DATA EVALUATION =====================
    evaluate_test_set(sess,
                      exact_tests,
                      softmax_prediction,
                      x)