Switch to side-by-side view

--- a
+++ b/Segmentation/train/train.py
@@ -0,0 +1,365 @@
+import sys
+import os
+from glob import glob
+import datetime
+import tensorflow as tf
+import numpy as np
+from time import time
+
+from Segmentation.train.utils import setup_gpu, LearningRateUpdate, Metric
+from Segmentation.train.reshape import get_mid_slice, get_mid_vol
+from Segmentation.train.validation import validate_best_model
+from Segmentation.utils.data_loader import read_tfrecord_3d
+from Segmentation.utils.visualise_utils import visualise_sample
+from Segmentation.utils.losses import dice_loss, tversky_loss, iou_loss
+from Segmentation.utils.losses import iou_loss_eval_3d, dice_coef_eval_3d
+from Segmentation.utils.losses import dice_loss_weighted_3d, focal_tversky
+from Segmentation.model.vnet import VNet
+
+class Train:
+    def __init__(self,
+                 epochs,
+                 batch_size,
+                 enable_function,
+                 model,
+                 optimizer,
+                 loss_func,
+                 lr_manager,
+                 predict_slice,
+                 metrics,
+                 tfrec_dir='./Data/tfrecords/',
+                 log_dir="logs"):
+
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.enable_function = enable_function
+        self.model = model
+        self.optimizer = optimizer
+        self.loss_func = loss_func
+        self.lr_manager = lr_manager
+        self.predict_slice = predict_slice
+        self.metrics = Metric(metrics)
+        self.tfrec_dir = tfrec_dir
+        self.log_dir = log_dir
+
+    def train_step(self,
+                   x_train,
+                   y_train,
+                   visualise):
+        with tf.GradientTape() as tape:
+            predictions = self.model(x_train, training=True)
+            loss = self.loss_func(y_train, predictions)
+        grads = tape.gradient(loss, self.model.trainable_variables)
+        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
+        self.metrics.store_metric(y_train, predictions, training=True)
+        if visualise:
+            return loss, predictions
+        return loss, None
+
+    def test_step(self,
+                  x_test,
+                  y_test,
+                  visualise):
+        predictions = self.model(x_test, training=False)
+        loss = self.loss_func(y_test, predictions)
+        self.metrics.store_metric(y_test, predictions, training=False)
+        if visualise:
+            return loss, predictions
+        return loss, None
+
+    def train_model_loop(self,
+                         train_ds,
+                         valid_ds,
+                         strategy,
+                         multi_class,
+                         visual_save_freq=5,
+                         debug=False,
+                         num_to_visualise=0):
+        """ Trains 3D model with custom tf loop and MirrorStrategy
+        """
+
+        def run_train_strategy(x, y, visualise):
+            total_step_loss, pred = strategy.run(self.train_step, args=(x, y, visualise, ))
+            return strategy.reduce(
+                tf.distribute.ReduceOp.SUM, total_step_loss, axis=None), pred
+
+        def run_test_strategy(x, y, visualise):
+            total_step_loss, pred = strategy.run(self.test_step, args=(x, y, visualise, ))
+            return strategy.reduce(
+                tf.distribute.ReduceOp.SUM, total_step_loss, axis=None), pred
+
+        # TODO(Joe): This needs to be rewritten so that it works with 2D as well
+        def distributed_train_epoch(train_ds,
+                                    epoch,
+                                    strategy,
+                                    num_to_visualise,
+                                    multi_class,
+                                    slice_writer,
+                                    vol_writer,
+                                    visual_save_freq,
+                                    predict_slice):
+
+            total_loss, num_train_batch = 0.0, 0.0
+            is_training = True
+            use_2d = False
+            for x_train, y_train in train_ds:
+                visualise = (num_train_batch < num_to_visualise)
+                loss, pred = run_train_strategy(x_train, y_train, visualise)
+                loss /= strategy.num_replicas_in_sync
+                total_loss += loss
+                if visualise:
+                    num_to_visualise = visualise_sample(x_train, y_train, pred, 
+                                                        num_to_visualise,
+                                                        slice_writer, vol_writer, 
+                                                        use_2d, epoch, multi_class, predict_slice, is_training)
+                num_train_batch += 1
+            return total_loss / num_train_batch
+
+        def distributed_test_epoch(valid_ds,
+                                   epoch,
+                                   strategy,
+                                   num_to_visualise,
+                                   multi_class,
+                                   slice_writer,
+                                   vol_writer,
+                                   visual_save_freq,
+                                   predict_slice):
+            total_loss, num_test_batch = 0.0, 0.0
+            is_training = False
+            use_2d = False
+            for x_valid, y_valid in valid_ds:
+                visualise = (num_test_batch < num_to_visualise)
+                loss, pred = run_test_strategy(x_valid, y_valid, visualise)
+                loss /= strategy.num_replicas_in_sync
+                total_loss += loss
+                if visualise:
+                    num_to_visualise = visualise_sample(x_train, y_train, pred, 
+                                                        num_to_visualise,
+                                                        slice_writer, vol_writer, 
+                                                        use_2d, epoch, multi_class, predict_slice, is_training)
+                num_test_batch += 1
+            return total_loss / num_test_batch
+
+        if self.enable_function:
+            run_train_strategy = tf.function(run_train_strategy)
+            run_test_strategy = tf.function(run_test_strategy)
+
+        # TODO: This whole chunk of code needs to be refactored. Perhaps write it as a function
+        name = "/" + self.model.name
+        db = "/debug" if debug else "/test"
+        mc = "/multi" if multi_class else "/binary"
+        log_dir_now = self.log_dir + name + db + mc + datetime.datetime.now().strftime("/%Y%m%d/%H%M%S")
+        train_summary_writer = tf.summary.create_file_writer(log_dir_now + '/train')
+        test_summary_writer = tf.summary.create_file_writer(log_dir_now + '/val')
+        test_min_summary_writer = tf.summary.create_file_writer(log_dir_now + '/val_min')
+        train_img_slice_writer = tf.summary.create_file_writer(log_dir_now + '/train/img/slice')
+        test_img_slice_writer = tf.summary.create_file_writer(log_dir_now + '/val/img/slice')
+        train_img_vol_writer = tf.summary.create_file_writer(log_dir_now + '/train/img/vol')
+        test_img_vol_writer = tf.summary.create_file_writer(log_dir_now + '/val/img/vol')
+        lr_summary_writer = tf.summary.create_file_writer(log_dir_now + '/lr')
+
+        self.metrics.add_metric_summary_writer(log_dir_now)
+
+        best_loss = None
+        for e in range(self.epochs):
+            self.optimizer.learning_rate = self.lr_manager.update_lr(e)
+
+            et0 = time()
+
+            train_loss = distributed_train_epoch(train_ds,
+                                                 e,
+                                                 strategy,
+                                                 num_to_visualise,
+                                                 multi_class,
+                                                 train_img_slice_writer,
+                                                 train_img_vol_writer,
+                                                 visual_save_freq,
+                                                 self.predict_slice)
+
+            with train_summary_writer.as_default():
+                tf.summary.scalar('epoch_loss', train_loss, step=e)
+
+            # distributed_test_epoch(valid_ds,
+            #                        e,
+            #                        strategy,
+            #                        num_to_visualise,
+            #                        multi_class,
+            #                        test_img_slice_writer,
+            #                        test_img_vol_writer,
+            #                        visual_save_freq,
+            #                        self.predict_slice)
+            test_loss = distributed_test_epoch(valid_ds,
+                                               e,
+                                               strategy,
+                                               num_to_visualise,
+                                               multi_class,
+                                               test_img_slice_writer,
+                                               test_img_vol_writer,
+                                               visual_save_freq,
+                                               self.predict_slice)
+            with test_summary_writer.as_default():
+                tf.summary.scalar('epoch_loss', test_loss, step=e)
+
+            current_lr = self.optimizer.get_config()['learning_rate']
+            with lr_summary_writer.as_default():
+                tf.summary.scalar('epoch_lr', current_lr, step=e)
+
+            self.metrics.record_metric_to_summary(e)
+            metric_str = self.metrics.reset_metrics_get_str()
+            print(f"Epoch {e+1}/{self.epochs} - {time() - et0:.0f}s - loss: {train_loss:.05f} - val_loss: {test_loss:.05f} - lr: {self.optimizer.get_config()['learning_rate']: .06f}" + metric_str)
+
+            if best_loss is None:
+                self.model.save_weights(os.path.join(log_dir_now + f'/best_weights.tf'))
+                best_loss = test_loss
+            else:
+                if test_loss < best_loss:
+                    self.model.save_weights(os.path.join(log_dir_now + f'/best_weights.tf'))
+                    best_loss = test_loss
+            with test_min_summary_writer.as_default():
+                tf.summary.scalar('epoch_loss', best_loss, step=e)
+        return log_dir_now
+
+
+def load_datasets(batch_size, buffer_size,
+                  tfrec_dir='./Data/tfrecords/',
+                  multi_class=False,
+                  crop_size=144,
+                  depth_crop_size=80,
+                  aug=[],
+                  predict_slice=False,
+                  ):
+    """
+    Loads tf records datasets for 3D models.
+    """
+    args = {
+        'batch_size': batch_size,
+        'buffer_size': buffer_size,
+        'multi_class': multi_class,
+        'use_keras_fit': False,
+        'crop_size': crop_size, 
+        'depth_crop_size': depth_crop_size,
+        'aug': aug,
+    }
+    train_ds = read_tfrecord_3d(tfrecords_dir=os.path.join(tfrec_dir, 'train_3d/'),
+                                is_training=True, predict_slice=predict_slice, **args)
+    valid_ds = read_tfrecord_3d(tfrecords_dir=os.path.join(tfrec_dir, 'valid_3d/'),
+                                is_training=False, predict_slice=predict_slice, **args)
+    return train_ds, valid_ds
+
+
+def build_model(num_channels, num_classes, name, **kwargs):
+    """
+    Builds standard vnet for 3D
+    """
+    model = VNet(num_channels, num_classes, name=name, **kwargs)
+    return model
+
+
+def main(epochs,
+         name,
+         log_dir_now=None,
+         batch_size=2,
+         val_batch_size=2,
+         lr=1e-4,
+         lr_drop=0.9,
+         lr_drop_freq=5,
+         lr_warmup=3,
+         num_to_visualise=2,
+         num_channels=4,
+         buffer_size=4,
+         enable_function=True,
+         tfrec_dir='./Data/tfrecords/',
+         multi_class=False,
+         crop_size=144,
+         depth_crop_size=80,
+         aug=[],
+         debug=False,
+         predict_slice=False,
+         tpu=False,
+         min_lr=1e-7,
+         custom_loss=None,
+         **model_kwargs,
+         ):
+    t0 = time()
+
+    if tpu:
+        tfrec_dir = 'gs://oai-challenge-dataset/tfrecords'
+
+    num_classes = 7 if multi_class else 1
+
+    metrics = {
+        'losses': {
+            'mIoU': [iou_loss, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None],
+            'dice': [dice_loss, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]
+        },
+    }
+
+    if multi_class:
+        metrics['losses']['mIoU-6ch'] = [iou_loss_eval_3d, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]
+        metrics['losses']['dice-6ch'] = [dice_coef_eval_3d, tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), None, None]
+
+    train_ds, valid_ds = load_datasets(batch_size, buffer_size, tfrec_dir, multi_class,
+                                       crop_size=crop_size, depth_crop_size=depth_crop_size, aug=aug,
+                                       predict_slice=predict_slice)
+
+    num_gpu = len(tf.config.experimental.list_physical_devices('GPU'))
+    steps_per_epoch = len(glob(os.path.join(tfrec_dir, 'train_3d/*'))) / (batch_size)
+
+    if tpu:
+        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='pit-tpu')
+        tf.config.experimental_connect_to_cluster(resolver)
+        tf.tpu.experimental.initialize_tpu_system(resolver)
+        strategy = tf.distribute.experimental.TPUStrategy(resolver)
+    else:
+        strategy = tf.distribute.MirroredStrategy()
+    # strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
+    with strategy.scope():
+        if custom_loss is None:
+            loss_func = tversky_loss if multi_class else dice_loss
+        elif multi_class and custom_loss == "weighted":
+            loss_func = dice_loss_weighted_3d
+        elif multi_class and custom_loss == "focal":
+            loss_func = focal_tversky
+        else:
+            raise NotImplementedError(f"Custom loss: {custom_loss} not implemented.")
+
+        lr_manager = LearningRateUpdate(lr, lr_drop, lr_drop_freq, warmup=lr_warmup, min_lr=min_lr)
+
+        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
+        model = build_model(num_channels, num_classes, name, predict_slice=predict_slice, **model_kwargs)
+
+        trainer = Train(epochs, batch_size, enable_function,
+                        model, optimizer, loss_func, lr_manager, predict_slice, metrics,
+                        tfrec_dir=tfrec_dir)
+
+        train_ds = strategy.experimental_distribute_dataset(train_ds)
+        valid_ds = strategy.experimental_distribute_dataset(valid_ds)
+
+        if log_dir_now is None:
+            log_dir_now = trainer.train_model_loop(train_ds, valid_ds, strategy, multi_class, debug, num_to_visualise)
+
+    train_time = time() - t0
+    print(f"Train Time: {train_time:.02f}")
+    t1 = time()
+    with strategy.scope():
+        model = build_model(num_channels, num_classes, name, predict_slice=predict_slice, **model_kwargs)
+        model.load_weights(os.path.join(log_dir_now + f'/best_weights.tf')).expect_partial()
+    print("Validation for:", log_dir_now)
+
+    if not predict_slice:
+        total_loss, metric_str = validate_best_model(model,
+                                                     log_dir_now,
+                                                     val_batch_size,
+                                                     buffer_size,
+                                                     tfrec_dir,
+                                                     multi_class,
+                                                     crop_size,
+                                                     depth_crop_size,
+                                                     predict_slice,
+                                                     Metric(metrics))
+        print(f"Train Time: {train_time:.02f}")
+        print(f"Validation Time: {time() - t1:.02f}")              
+        print(f"Total Time: {time() - t0:.02f}")
+        with open("results/3d_result.txt", "a") as f:
+            f.write(f'{log_dir_now}: total_loss {total_loss} {metric_str} \n')