Diff of /train.py [000000] .. [e77990]

--- a
+++ b/train.py
@@ -0,0 +1,271 @@
+import os
+import argparse
+import datetime
+import uuid
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+from azureml.core.run import Run
+from azureml.core.model import Model
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import (
+    Flatten, Dense, Reshape, Conv2D, MaxPool2D, Conv2DTranspose)
+
+
+class DisplayCallback(tf.keras.callbacks.Callback):
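+    """Prints a short notice after each epoch so progress is visible in the run logs."""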
+    def on_epoch_end(self, epoch, logs=None):
+        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))
+
+
+class Train():
+
+    def __init__(self):
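+        """Parse arguments, attach to the Azure ML run context, define the TFRecord
+        feature schema, build the model and load the training/validation datasets."""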
+
+        self._parser = argparse.ArgumentParser("train")
+        self._parser.add_argument("--model_name", type=str, help="Name of the tf model")
+
+        self._args = self._parser.parse_args()
+        self._run = Run.get_context()
+        self._exp = self._run.experiment
+        self._ws = self._run.experiment.workspace
+        self._image_feature_description = {
+            'height':    tf.io.FixedLenFeature([], tf.int64),
+            'width':     tf.io.FixedLenFeature([], tf.int64),
+            'depth':     tf.io.FixedLenFeature([], tf.int64),
+            'name' :     tf.io.FixedLenFeature([], tf.string),
+            'image_raw': tf.io.FixedLenFeature([], tf.string),
+            'label_raw': tf.io.FixedLenFeature([], tf.string),
+        }
+        self._model = self.__get_model()
+        self._parsed_training_dataset, self._parsed_val_dataset = self.__load_dataset()
+        self.__steps_per_epoch = len(list(self._parsed_training_dataset))
+        self._buffer_size = 10
+        self._batch_size = 1
+        self.__epochs = 30
+
+
+    def main(self):
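+        """Build the tf.data pipelines, log sample images, train and evaluate the
+        model, plot the training curves and register the trained model."""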
+        plt.rcParams['image.cmap'] = 'Greys_r'
+
+        tf_autotune = tf.data.experimental.AUTOTUNE
+        train = self._parsed_training_dataset.map(
+            self.__read_and_decode, num_parallel_calls=tf_autotune)
+        val = self._parsed_val_dataset.map(self.__read_and_decode)
+
+        train_dataset = train.cache().shuffle(self._buffer_size).batch(self._batch_size).repeat()
+        train_dataset = train_dataset.prefetch(buffer_size=tf_autotune)
+        test_dataset  = val.batch(self._batch_size)
+
+        for image, label in train.take(2):
+            self._sample_image, self._sample_label = image, label
+            self.__display("Training Images", [self._sample_image, self._sample_label])
+
+        for image, label in val.take(2):
+            self._sample_image, self._sample_label = image, label
+            self.__display("Eval Images", [self._sample_image, self._sample_label])
+
+        logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+        tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
+
+        tf.keras.backend.clear_session()
+
+        self._model = self.__get_model()
+
+        model_history = self._model.fit(train_dataset, epochs=self.__epochs,
+                                        steps_per_epoch=self.__steps_per_epoch,
+                                        validation_data=test_dataset,
+                                        callbacks=[DisplayCallback(), tensorboard_callback])
+
+        metrics_results = self._model.evaluate(test_dataset)
+        self._run.log("DICE", "{:.2f}%".format(metrics_results[0]))
+        self._run.log("Accuracy", "{:.2f}%".format(metrics_results[1]))
+
+        self.__plot_training_logs(model_history)
+        self.__show_predictions(test_dataset, 5)
+        self.__register_model(metrics_results)
+
+    
+    def __parse_image_function(self, example_proto):
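+        """Parse a single serialized example using the TFRecord feature schema."""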
+        return tf.io.parse_single_example(example_proto, self._image_feature_description)
+
+
+    def __load_dataset(self):
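+        """Read the training and validation TFRecord files from the data/ folder
+        and parse every record."""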
+        raw_training_dataset = tf.data.TFRecordDataset('data/train_images.tfrecords')
+        raw_val_dataset      = tf.data.TFRecordDataset('data/val_images.tfrecords')
+
+        parsed_training_dataset = raw_training_dataset.map(self.__parse_image_function)
+        parsed_val_dataset = raw_val_dataset.map(self.__parse_image_function)
+
+        return parsed_training_dataset, parsed_val_dataset
+
+
+    @tf.function
+    def __read_and_decode(self, example):
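+        """Decode the raw image bytes into a scaled float32 [256, 256, 1] tensor and
+        the raw label bytes into a uint8 [256, 256, 1] segmentation mask."""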
+        image_raw = tf.io.decode_raw(example['image_raw'], tf.int64)
+        image_raw.set_shape([65536])
+        image = tf.reshape(image_raw, [256, 256, 1])
+
+        image = tf.cast(image, tf.float32) * (1. / 1024)
+
+        label_raw = tf.io.decode_raw(example['label_raw'], tf.uint8)
+        label_raw.set_shape([65536])
+        label = tf.reshape(label_raw, [256, 256, 1])
+
+        return image, label
+
+
+    def __display(self, image_title, display_list):
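+        """Plot the given images side by side and log the figure to the Azure ML run."""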
+        plt.figure(figsize=(10, 10))
+        title = ['Input Image', 'Label', 'Predicted Label']
+
+        for i in range(len(display_list)):
+            display_resized = tf.reshape(display_list[i], [256, 256])
+            plt.subplot(1, len(display_list), i+1)
+            plt.title(title[i])
+            plt.imshow(display_resized)
+            plt.axis('off')
+        self._run.log_image(f'{image_title} {uuid.uuid4()}', plot=plt)
+
+
+    def __create_mask(self, pred_mask):
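+        """Collapse the two-channel logits into a single-channel mask via argmax."""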
+        pred_mask = tf.argmax(pred_mask, axis=-1)
+        pred_mask = pred_mask[..., tf.newaxis]
+        return pred_mask[0]
+
+
+    def __show_predictions(self, dataset=None, num=1):
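+        """Predict masks for `num` batches of `dataset` and log the image/label/prediction
+        triplets; fall back to the last stored sample image otherwise."""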
+        if dataset:
+            for image, label in dataset.take(num):
+                pred_mask = self._model.predict(image)
+                self.__display("Show predictions", [image[0], label[0], self.__create_mask(pred_mask)])
+        else:
+            prediction = self.__create_mask(
+                self._model.predict(self._sample_image[tf.newaxis, ...]))
+            self.__display("Show predictions sample image",
+                           [self._sample_image, self._sample_label, prediction])
+
+    
+    def __get_dice_coef(self, y_true, y_pred, smooth=1):
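+        """Dice coefficient on the argmax predictions:
+        (2 * intersection + smooth) / (sum of the two areas + smooth)."""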
+        indices = K.argmax(y_pred, 3)
+        indices = K.reshape(indices, [-1, 256, 256, 1])
+
+        true_cast = y_true
+        indices_cast = K.cast(indices, dtype='float32')
+
+        axis = [1, 2, 3]
+        intersection = K.sum(true_cast * indices_cast, axis=axis)
+        union = K.sum(true_cast, axis=axis) + K.sum(indices_cast, axis=axis)
+        dice = K.mean((2. * intersection + smooth)/(union + smooth), axis=0)
+
+        return dice
+
+    
+    def __get_model(self):
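+        """Build and compile a small fully convolutional segmentation network:
+        strided convolutions and max pooling downsample 256x256 to 16x16, a 1x1
+        convolution maps to two classes and a Conv2DTranspose upsamples 16x back
+        to full resolution. Compiled with sparse categorical cross-entropy on logits."""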
+        layers = [
+            Conv2D(input_shape=[256, 256, 1],
+                filters=100,
+                kernel_size=5,
+                strides=2,
+                padding="same",
+                activation=tf.nn.relu,
+                name="Conv1"),
+            MaxPool2D(pool_size=2, strides=2, padding="same"),
+            Conv2D(filters=200,
+                kernel_size=5,
+                strides=2,
+                padding="same",
+                activation=tf.nn.relu),
+            MaxPool2D(pool_size=2, strides=2, padding="same"),
+            Conv2D(filters=300,
+                kernel_size=3,
+                strides=1,
+                padding="same",
+                activation=tf.nn.relu),
+            Conv2D(filters=300,
+                kernel_size=3,
+                strides=1,
+                padding="same",
+                activation=tf.nn.relu),
+            Conv2D(filters=2,
+                kernel_size=1,
+                strides=1,
+                padding="same",
+                activation=tf.nn.relu),
+            Conv2DTranspose(filters=2, kernel_size=31, strides=16, padding="same")
+        ]
+
+        tf.keras.backend.clear_session()
+        model = tf.keras.models.Sequential(layers)
+
+        model.compile(
+            optimizer='adam',
+            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            metrics=[self.__get_dice_coef, 'accuracy', self.__f1_score,
+                    self.__precision, self.__recall])
+        
+        return model
+
+
+    def __plot_training_logs(self, model_history):
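+        """Plot the training/validation loss and the Dice coefficient per epoch and
+        log the figure to the Azure ML run."""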
+        loss = model_history.history['loss']
+        val_loss = model_history.history['val_loss']
+        accuracy = model_history.history['accuracy']
+        val_accuracy = model_history.history['val_accuracy']
+        dice = model_history.history['__get_dice_coef']
+
+        epochs = range(self.__epochs)
+
+        plt.figure()
+        plt.plot(epochs, loss, 'r', label='Training loss')
+        plt.plot(epochs, val_loss, 'bo', label='Validation loss')
+        plt.plot(epochs, dice, 'go', label='Dice Coefficient')
+        plt.title('Training and Validation Loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss Value')
+        plt.ylim([0, 1])
+        plt.legend()
+        self._run.log_image("Training and Validation Loss", plot=plt)
+
+
+    def __recall(self, y_true, y_pred):
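+        """Recall over the flattened batch: true positives / actual positives."""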
+        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+        all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
+        
+        recall = true_positives / (all_positives + K.epsilon())
+        return recall
+
+
+    def __precision(self, y_true, y_pred):
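+        """Precision over the flattened batch: true positives / predicted positives."""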
+        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+        
+        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
+        precision = true_positives / (predicted_positives + K.epsilon())
+        return precision
+
+
+    def __f1_score(self, y_true, y_pred):
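+        """Harmonic mean of precision and recall."""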
+        precision = self.__precision(y_true, y_pred)
+        recall = self.__recall(y_true, y_pred)
+        return 2*((precision*recall)/(precision+recall+K.epsilon()))
+
+
+    def __register_model(self, metrics_results):
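+        """Save the model in TensorFlow SavedModel format and register it in the
+        workspace, tagged with the evaluation DICE and accuracy."""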
+        tf.keras.models.save_model(
+            self._model, "./model", overwrite=True, include_optimizer=True,
+            save_format="tf", signatures=None, options=None)
+        Model.register(workspace=self._ws,
+                       model_path="./model",
+                       model_name=self._args.model_name,
+                       properties={"run_id": self._run.id,
+                                   "experiment": self._run.experiment.name},
+                       tags={
+                           "DICE": float(metrics_results[1]),
+                           "Accuracy": float(metrics_results[2])
+                       })
+
+
+if __name__ == '__main__':
+    tr = Train()
+    tr.main()
\ No newline at end of file