Diff of /train.py [000000] .. [e77990]

Switch to unified view

a b/train.py
1
import os
2
import argparse
3
import datetime
4
import uuid
5
import tensorflow as tf
6
import matplotlib.pyplot as plt
7
8
from azureml.core.run import Run
9
from azureml.core import Datastore
10
from azureml.core.model import Model, Dataset
11
from tensorflow.keras import backend as K
12
from tensorflow.keras.layers import (
13
    Flatten, Dense, Reshape, Conv2D, MaxPool2D, Conv2DTranspose)
14
15
16
class DisplayCallback(tf.keras.callbacks.Callback):
17
    def on_epoch_end(self, epoch, logs=None):
18
        print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
19
20
21
class Train():
22
23
    def __init__(self):
24
25
        self._parser = argparse.ArgumentParser("train")
26
        self._parser.add_argument("--model_name", type=str, help="Name of the tf model")
27
28
        self._args = self._parser.parse_args()
29
        self._run = Run.get_context()
30
        self._exp = self._run.experiment
31
        self._ws = self._run.experiment.workspace
32
        self._image_feature_description = {
33
            'height':    tf.io.FixedLenFeature([], tf.int64),
34
            'width':     tf.io.FixedLenFeature([], tf.int64),
35
            'depth':     tf.io.FixedLenFeature([], tf.int64),
36
            'name' :     tf.io.FixedLenFeature([], tf.string),
37
            'image_raw': tf.io.FixedLenFeature([], tf.string),
38
            'label_raw': tf.io.FixedLenFeature([], tf.string),
39
        }
40
        self._model = self.__get_model()
41
        self._parsed_training_dataset, self._parsed_val_dataset = self.__load_dataset()
42
        self.__steps_per_epoch = len(list(self._parsed_training_dataset))
43
        self._buffer_size = 10
44
        self._batch_size = 1
45
        self.__epochs = 30
46
47
48
    def main(self):
49
        plt.rcParams['image.cmap'] = 'Greys_r'
50
51
        tf_autotune = tf.data.experimental.AUTOTUNE
52
        train = self._parsed_training_dataset.map(
53
            self.__read_and_decode, num_parallel_calls=tf_autotune)
54
        val = self._parsed_val_dataset.map(self.__read_and_decode)
55
56
        train_dataset = train.cache().shuffle(self._buffer_size).batch(self._batch_size).repeat()
57
        train_dataset = train_dataset.prefetch(buffer_size=tf_autotune)
58
        test_dataset  = val.batch(self._batch_size)
59
60
        for image, label in train.take(2):
61
            sample_image, sample_label = image, label
62
            self.__display("Training Images", [sample_image, sample_label])
63
64
        for image, label in val.take(2):
65
            sample_image, sample_label = image, label
66
            self.__display("Eval Images", [sample_image, sample_label])
67
68
        logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
69
        tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
70
71
        tf.keras.backend.clear_session()
72
73
        self._model = self.__get_model()
74
75
        model_history = self._model.fit(train_dataset, epochs=self.__epochs,
76
                          steps_per_epoch=self.__steps_per_epoch,
77
                          validation_data=test_dataset,
78
                          callbacks=[DisplayCallback()])
79
80
        metrics_results = self._model.evaluate(test_dataset)
81
        self._run.log("DICE", "{:.2f}%".format(metrics_results[0]))
82
        self._run.log("Accuracy", "{:.2f}%".format(metrics_results[1]))
83
84
        self.__plot_training_logs(model_history)
85
        self.__show_predictions(test_dataset, 5)
86
        self.__register_model(metrics_results)
87
88
    
89
    def __parse_image_function(self, example_proto):
90
        return tf.io.parse_single_example(example_proto, self._image_feature_description)
91
92
93
    def __load_dataset(self):
94
        raw_training_dataset = tf.data.TFRecordDataset('data/train_images.tfrecords')
95
        raw_val_dataset      = tf.data.TFRecordDataset('data/val_images.tfrecords')
96
97
        parsed_training_dataset = raw_training_dataset.map(self.__parse_image_function)
98
        parsed_val_dataset = raw_val_dataset.map(self.__parse_image_function)
99
100
        return parsed_training_dataset, parsed_val_dataset
101
102
103
    @tf.function
104
    def __read_and_decode(self, example):
105
        image_raw = tf.io.decode_raw(example['image_raw'], tf.int64)
106
        image_raw.set_shape([65536])
107
        image = tf.reshape(image_raw, [256, 256, 1])
108
109
        image = tf.cast(image, tf.float32) * (1. / 1024)
110
111
        label_raw = tf.io.decode_raw(example['label_raw'], tf.uint8)
112
        label_raw.set_shape([65536])
113
        label = tf.reshape(label_raw, [256, 256, 1])
114
115
        return image, label
116
117
118
    def __display(self, image_title, display_list):
119
        plt.figure(figsize=(10, 10))
120
        title = ['Input Image', 'Label', 'Predicted Label']
121
122
        for i in range(len(display_list)):
123
            display_resized = tf.reshape(display_list[i], [256, 256])
124
            plt.subplot(1, len(display_list), i+1)
125
            plt.title(title[i])
126
            plt.imshow(display_resized)
127
            plt.axis('off')
128
        title = uuid.uuid4()
129
        self._run.log_image(f'{title}', plot=plt)
130
131
132
    def __create_mask(self, pred_mask):
133
        pred_mask = tf.argmax(pred_mask, axis=-1)
134
        pred_mask = pred_mask[..., tf.newaxis]
135
        return pred_mask[0]
136
137
138
    def __show_predictions(self, dataset=None, num=1):
139
        if dataset:
140
            for image, label in dataset.take(num):
141
                pred_mask = self._model.predict(image)
142
                self.__display("Show predictions", [image[0], label[0], self.__create_mask(pred_mask)])
143
        else:
144
            prediction = self.__create_mask(self._.predict(sample_image[tf.newaxis, ...]))
145
            self.__display("Show predictions sample image", [sample_image, sample_label, prediction])
146
147
    
148
    def __get_dice_coef(self, y_true, y_pred, smooth=1):
149
        indices = K.argmax(y_pred, 3)
150
        indices = K.reshape(indices, [-1, 256, 256, 1])
151
152
        true_cast = y_true
153
        indices_cast = K.cast(indices, dtype='float32')
154
155
        axis = [1, 2, 3]
156
        intersection = K.sum(true_cast * indices_cast, axis=axis)
157
        union = K.sum(true_cast, axis=axis) + K.sum(indices_cast, axis=axis)
158
        dice = K.mean((2. * intersection + smooth)/(union + smooth), axis=0)
159
160
        return dice
161
162
    
163
    def __get_model(self):
164
        layers = [
165
            Conv2D(input_shape=[256, 256, 1],
166
                filters=100,
167
                kernel_size=5,
168
                strides=2,
169
                padding="same",
170
                activation=tf.nn.relu,
171
                name="Conv1"),
172
            MaxPool2D(pool_size=2, strides=2, padding="same"),
173
            Conv2D(filters=200,
174
                kernel_size=5,
175
                strides=2,
176
                padding="same",
177
                activation=tf.nn.relu),
178
            MaxPool2D(pool_size=2, strides=2, padding="same"),
179
            Conv2D(filters=300,
180
                kernel_size=3,
181
                strides=1,
182
                padding="same",
183
                activation=tf.nn.relu),
184
            Conv2D(filters=300,
185
                kernel_size=3,
186
                strides=1,
187
                padding="same",
188
                activation=tf.nn.relu),
189
            Conv2D(filters=2,
190
                kernel_size=1,
191
                strides=1,
192
                padding="same",
193
                activation=tf.nn.relu),
194
            Conv2DTranspose(filters=2, kernel_size=31, strides=16, padding="same")
195
        ]
196
197
        tf.keras.backend.clear_session()
198
        model = tf.keras.models.Sequential(layers)
199
200
        model.compile(
201
            optimizer='adam',
202
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
203
            metrics=[self.__get_dice_coef, 'accuracy', self.__f1_score,
204
                    self.__precision, self.__recall])
205
        
206
        return model
207
208
209
    def __plot_training_logs(self, model_history):
210
        loss = model_history.history['loss']
211
        val_loss = model_history.history['val_loss']
212
        accuracy = model_history.history['accuracy']
213
        val_accuracy = model_history.history['val_accuracy']
214
        dice = model_history.history['__get_dice_coef']
215
216
        epochs = range(self.__epochs)
217
218
        plt.figure()
219
        plt.plot(epochs, loss, 'r', label='Training loss')
220
        plt.plot(epochs, val_loss, 'bo', label='Validation loss')
221
        plt.plot(epochs, dice, 'go', label='Dice Coefficient')
222
        plt.title('Training and Validation Loss')
223
        plt.xlabel('Epoch')
224
        plt.ylabel('Loss Value')
225
        plt.ylim([0, 1])
226
        plt.legend()
227
        self._run.log_image("Training and Validation Loss", plot=plt)
228
229
230
    def __recall(self, y_true, y_pred):
231
        y_true = K.ones_like(y_true) 
232
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
233
        all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
234
        
235
        recall = true_positives / (all_positives + K.epsilon())
236
        return recall
237
238
239
    def __precision(self, y_true, y_pred):
240
        y_true = K.ones_like(y_true) 
241
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
242
        
243
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
244
        precision = true_positives / (predicted_positives + K.epsilon())
245
        return precision
246
247
248
    def __f1_score(self, y_true, y_pred):
249
        precision = self.__precision(y_true, y_pred)
250
        recall = self.__recall(y_true, y_pred)
251
        return 2*((precision*recall)/(precision+recall+K.epsilon()))
252
253
254
    def __register_model(self, metrics_results):
255
        tf.keras.models.save_model(
256
            self._model, "./model", overwrite=True, include_optimizer=True, save_format=tf,
257
            signatures=None, options=None)
258
        Model.register(workspace=self._ws,
259
                    model_path="./model",
260
                    model_name=self._args.model_name,
261
                    properties = {"run_id": self._run.id,
262
                                "experiment": self._run.experiment.name},
263
                    tags={
264
                        "DICE": float(metrics_results[0]),
265
                        "Accuracy": float(metrics_results[1])
266
                    })
267
268
269
if __name__ == '__main__':
270
    tr = Train()
271
    tr.main()