--- a +++ b/train.py @@ -0,0 +1,271 @@ +import os +import argparse +import datetime +import uuid +import tensorflow as tf +import matplotlib.pyplot as plt + +from azureml.core.run import Run +from azureml.core import Datastore +from azureml.core.model import Model, Dataset +from tensorflow.keras import backend as K +from tensorflow.keras.layers import ( + Flatten, Dense, Reshape, Conv2D, MaxPool2D, Conv2DTranspose) + + +class DisplayCallback(tf.keras.callbacks.Callback): + def on_epoch_end(self, epoch, logs=None): + print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) + + +class Train(): + + def __init__(self): + + self._parser = argparse.ArgumentParser("train") + self._parser.add_argument("--model_name", type=str, help="Name of the tf model") + + self._args = self._parser.parse_args() + self._run = Run.get_context() + self._exp = self._run.experiment + self._ws = self._run.experiment.workspace + self._image_feature_description = { + 'height': tf.io.FixedLenFeature([], tf.int64), + 'width': tf.io.FixedLenFeature([], tf.int64), + 'depth': tf.io.FixedLenFeature([], tf.int64), + 'name' : tf.io.FixedLenFeature([], tf.string), + 'image_raw': tf.io.FixedLenFeature([], tf.string), + 'label_raw': tf.io.FixedLenFeature([], tf.string), + } + self._model = self.__get_model() + self._parsed_training_dataset, self._parsed_val_dataset = self.__load_dataset() + self.__steps_per_epoch = len(list(self._parsed_training_dataset)) + self._buffer_size = 10 + self._batch_size = 1 + self.__epochs = 30 + + + def main(self): + plt.rcParams['image.cmap'] = 'Greys_r' + + tf_autotune = tf.data.experimental.AUTOTUNE + train = self._parsed_training_dataset.map( + self.__read_and_decode, num_parallel_calls=tf_autotune) + val = self._parsed_val_dataset.map(self.__read_and_decode) + + train_dataset = train.cache().shuffle(self._buffer_size).batch(self._batch_size).repeat() + train_dataset = train_dataset.prefetch(buffer_size=tf_autotune) + test_dataset = val.batch(self._batch_size) + + for image, label in train.take(2): + sample_image, sample_label = image, label + self.__display("Training Images", [sample_image, sample_label]) + + for image, label in val.take(2): + sample_image, sample_label = image, label + self.__display("Eval Images", [sample_image, sample_label]) + + logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) + tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1) + + tf.keras.backend.clear_session() + + self._model = self.__get_model() + + model_history = self._model.fit(train_dataset, epochs=self.__epochs, + steps_per_epoch=self.__steps_per_epoch, + validation_data=test_dataset, + callbacks=[DisplayCallback()]) + + metrics_results = self._model.evaluate(test_dataset) + self._run.log("DICE", "{:.2f}%".format(metrics_results[0])) + self._run.log("Accuracy", "{:.2f}%".format(metrics_results[1])) + + self.__plot_training_logs(model_history) + self.__show_predictions(test_dataset, 5) + self.__register_model(metrics_results) + + + def __parse_image_function(self, example_proto): + return tf.io.parse_single_example(example_proto, self._image_feature_description) + + + def __load_dataset(self): + raw_training_dataset = tf.data.TFRecordDataset('data/train_images.tfrecords') + raw_val_dataset = tf.data.TFRecordDataset('data/val_images.tfrecords') + + parsed_training_dataset = raw_training_dataset.map(self.__parse_image_function) + parsed_val_dataset = raw_val_dataset.map(self.__parse_image_function) + + return parsed_training_dataset, parsed_val_dataset + + + @tf.function + def __read_and_decode(self, example): + image_raw = tf.io.decode_raw(example['image_raw'], tf.int64) + image_raw.set_shape([65536]) + image = tf.reshape(image_raw, [256, 256, 1]) + + image = tf.cast(image, tf.float32) * (1. / 1024) + + label_raw = tf.io.decode_raw(example['label_raw'], tf.uint8) + label_raw.set_shape([65536]) + label = tf.reshape(label_raw, [256, 256, 1]) + + return image, label + + + def __display(self, image_title, display_list): + plt.figure(figsize=(10, 10)) + title = ['Input Image', 'Label', 'Predicted Label'] + + for i in range(len(display_list)): + display_resized = tf.reshape(display_list[i], [256, 256]) + plt.subplot(1, len(display_list), i+1) + plt.title(title[i]) + plt.imshow(display_resized) + plt.axis('off') + title = uuid.uuid4() + self._run.log_image(f'{title}', plot=plt) + + + def __create_mask(self, pred_mask): + pred_mask = tf.argmax(pred_mask, axis=-1) + pred_mask = pred_mask[..., tf.newaxis] + return pred_mask[0] + + + def __show_predictions(self, dataset=None, num=1): + if dataset: + for image, label in dataset.take(num): + pred_mask = self._model.predict(image) + self.__display("Show predictions", [image[0], label[0], self.__create_mask(pred_mask)]) + else: + prediction = self.__create_mask(self._.predict(sample_image[tf.newaxis, ...])) + self.__display("Show predictions sample image", [sample_image, sample_label, prediction]) + + + def __get_dice_coef(self, y_true, y_pred, smooth=1): + indices = K.argmax(y_pred, 3) + indices = K.reshape(indices, [-1, 256, 256, 1]) + + true_cast = y_true + indices_cast = K.cast(indices, dtype='float32') + + axis = [1, 2, 3] + intersection = K.sum(true_cast * indices_cast, axis=axis) + union = K.sum(true_cast, axis=axis) + K.sum(indices_cast, axis=axis) + dice = K.mean((2. * intersection + smooth)/(union + smooth), axis=0) + + return dice + + + def __get_model(self): + layers = [ + Conv2D(input_shape=[256, 256, 1], + filters=100, + kernel_size=5, + strides=2, + padding="same", + activation=tf.nn.relu, + name="Conv1"), + MaxPool2D(pool_size=2, strides=2, padding="same"), + Conv2D(filters=200, + kernel_size=5, + strides=2, + padding="same", + activation=tf.nn.relu), + MaxPool2D(pool_size=2, strides=2, padding="same"), + Conv2D(filters=300, + kernel_size=3, + strides=1, + padding="same", + activation=tf.nn.relu), + Conv2D(filters=300, + kernel_size=3, + strides=1, + padding="same", + activation=tf.nn.relu), + Conv2D(filters=2, + kernel_size=1, + strides=1, + padding="same", + activation=tf.nn.relu), + Conv2DTranspose(filters=2, kernel_size=31, strides=16, padding="same") + ] + + tf.keras.backend.clear_session() + model = tf.keras.models.Sequential(layers) + + model.compile( + optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[self.__get_dice_coef, 'accuracy', self.__f1_score, + self.__precision, self.__recall]) + + return model + + + def __plot_training_logs(self, model_history): + loss = model_history.history['loss'] + val_loss = model_history.history['val_loss'] + accuracy = model_history.history['accuracy'] + val_accuracy = model_history.history['val_accuracy'] + dice = model_history.history['__get_dice_coef'] + + epochs = range(self.__epochs) + + plt.figure() + plt.plot(epochs, loss, 'r', label='Training loss') + plt.plot(epochs, val_loss, 'bo', label='Validation loss') + plt.plot(epochs, dice, 'go', label='Dice Coefficient') + plt.title('Training and Validation Loss') + plt.xlabel('Epoch') + plt.ylabel('Loss Value') + plt.ylim([0, 1]) + plt.legend() + self._run.log_image("Training and Validation Loss", plot=plt) + + + def __recall(self, y_true, y_pred): + y_true = K.ones_like(y_true) + true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) + all_positives = K.sum(K.round(K.clip(y_true, 0, 1))) + + recall = true_positives / (all_positives + K.epsilon()) + return recall + + + def __precision(self, y_true, y_pred): + y_true = K.ones_like(y_true) + true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) + + predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) + precision = true_positives / (predicted_positives + K.epsilon()) + return precision + + + def __f1_score(self, y_true, y_pred): + precision = self.__precision(y_true, y_pred) + recall = self.__recall(y_true, y_pred) + return 2*((precision*recall)/(precision+recall+K.epsilon())) + + + def __register_model(self, metrics_results): + tf.keras.models.save_model( + self._model, "./model", overwrite=True, include_optimizer=True, save_format=tf, + signatures=None, options=None) + Model.register(workspace=self._ws, + model_path="./model", + model_name=self._args.model_name, + properties = {"run_id": self._run.id, + "experiment": self._run.experiment.name}, + tags={ + "DICE": float(metrics_results[0]), + "Accuracy": float(metrics_results[1]) + }) + + +if __name__ == '__main__': + tr = Train() + tr.main() \ No newline at end of file