--- /dev/null
+++ b/main.py
@@ -0,0 +1,247 @@
+import tensorflow as tf
+
+import os
+from pathlib import Path
+from datetime import datetime
+from absl import app
+from absl import logging
+
+from Segmentation.utils.data_loader import read_tfrecord_2d as read_tfrecord
+from Segmentation.utils.data_loader import parse_fn_2d, parse_fn_3d
+from Segmentation.utils.losses import dice_coef_loss, tversky_loss, dice_coef, iou_loss  # focal_tversky
+from Segmentation.utils.evaluation_metrics import dice_coef_eval, iou_loss_eval
+from Segmentation.utils.training_utils import LearningRateSchedule
+# from Segmentation.utils.evaluation_utils import plot_and_eval_3D, confusion_matrix, epoch_gif, volume_gif, take_slice
+from Segmentation.utils.evaluation_utils import eval_loop
+from Segmentation.train.train import Train
+
+from flags import FLAGS
+from select_model import select_model
+
+
+def main(argv):
+
+    if FLAGS.visual_file:
+        assert not FLAGS.train, "train must be set to False when producing visualisations."
+
+    del argv  # unused arg
+    tf.random.set_seed(FLAGS.seed)  # set seed
+
+    # set whether to train on GPU or TPU
+    if FLAGS.use_gpu:
+        logging.info('Using GPU...')
+        # enable memory growth before the strategy initialises the GPUs;
+        # equivalent to exporting TF_FORCE_GPU_ALLOW_GROWTH=true in the shell
+        gpus = tf.config.experimental.list_physical_devices('GPU')
+        if gpus:
+            try:
+                for gpu in gpus:
+                    tf.config.experimental.set_memory_growth(gpu, True)
+                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
+                logging.info('%d Physical GPUs, %d Logical GPUs',
+                             len(gpus), len(logical_gpus))
+            except RuntimeError as e:
+                # memory growth must be set before the GPUs are initialized
+                logging.error(e)
+        if FLAGS.num_cores == 1:
+            strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
+        else:
+            strategy = tf.distribute.MirroredStrategy()
+    else:
+        logging.info('Using TPU at %s',
+                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
+        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
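+        # the cluster must be connected and the TPU system initialised before
+        # any other TensorFlow ops touch the TPU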
+        tf.config.experimental_connect_to_cluster(resolver)
+        tf.tpu.experimental.initialize_tpu_system(resolver)
+        strategy = tf.distribute.experimental.TPUStrategy(resolver)
+
+    # set dataset configuration
+    if FLAGS.dataset == 'oai_challenge':
+
+        batch_size = FLAGS.batch_size * FLAGS.num_cores
+
+        if FLAGS.use_2d:
+            steps_per_epoch = 19200 // batch_size
+            validation_steps = 4480 // batch_size
+        else:
+            steps_per_epoch = 120 // batch_size
+            validation_steps = 28 // batch_size
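+        # the 2D set holds 19200 training / 4480 validation slices, so e.g. a
+        # global batch size of 32 gives 600 / 140 steps; the 3D pipeline reads
+        # whole volumes (160 slices each), hence the 120 / 28 counts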
+
+        logging.info('Using augmentation strategy: %s', FLAGS.aug_strategy)
+
+        train_dir = 'train/' if FLAGS.use_2d else 'train_3d/'
+        valid_dir = 'valid/' if FLAGS.use_2d else 'valid_3d/'
+
+        ds_args = {
+            'batch_size': batch_size,
+            'buffer_size': FLAGS.buffer_size,
+            'augmentation': FLAGS.aug_strategy,
+            'parse_fn': parse_fn_2d if FLAGS.use_2d else parse_fn_3d,
+            'multi_class': FLAGS.multi_class,
+            'is_training': True,
+            'use_bfloat16': FLAGS.use_bfloat16,
+            'use_RGB': FLAGS.backbone_architecture != 'default'
+        }
+
+        train_ds = read_tfrecord(tfrecords_dir=os.path.join(FLAGS.tfrec_dir, train_dir),
+                                 **ds_args)
+        # the validation pipeline should not shuffle or augment
+        valid_ds = read_tfrecord(tfrecords_dir=os.path.join(FLAGS.tfrec_dir, valid_dir),
+                                 **dict(ds_args, is_training=False))
+
+        num_classes = 7 if FLAGS.multi_class else 1
+    else:
+        # fail fast instead of hitting a NameError on train_ds later
+        raise ValueError('Unsupported dataset: {}'.format(FLAGS.dataset))
+
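+    # Tversky loss generalises Dice with separate false-positive / false-negative
+    # weights, which helps with class imbalance across the 7 classes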
+    if FLAGS.multi_class:
+        loss_fn = tversky_loss
+        crossentropy_loss_fn = tf.keras.losses.categorical_crossentropy
+    else:
+        loss_fn = dice_coef_loss
+        crossentropy_loss_fn = tf.keras.losses.binary_crossentropy
+
+    if FLAGS.use_bfloat16:
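+        # mixed_bfloat16 keeps variables in float32 and casts compute to
+        # bfloat16, the native 16-bit format on TPUs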
+        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
+        tf.keras.mixed_precision.experimental.set_policy(policy)
+
+    # set model architecture
+
+    model_fn, model_args = select_model(FLAGS, num_classes)
+
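+    # the model, optimizer and metrics must be created inside strategy.scope()
+    # so their variables are placed / mirrored correctly across replicas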
+    with strategy.scope():
+        model = model_fn(*model_args)
+
+        if FLAGS.custom_decay_lr:
+            lr_decay_epochs = FLAGS.lr_decay_epochs
+        else:
+            lr_decay_epochs = list(range(FLAGS.lr_warmup_epochs + 1, FLAGS.train_epochs))
+
+        lr_rate = LearningRateSchedule(steps_per_epoch,
+                                       FLAGS.base_learning_rate,
+                                       FLAGS.lr_drop_ratio,
+                                       lr_decay_epochs,
+                                       FLAGS.lr_warmup_epochs)
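+        # judging by its arguments, the schedule warms up for lr_warmup_epochs,
+        # then multiplies the LR by lr_drop_ratio at each epoch in lr_decay_epochs
+        # (e.g. 1e-3 -> 5e-4 at the first drop when the ratio is 0.5)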
+
+        if FLAGS.optimizer == 'adam':
+            optimiser = tf.keras.optimizers.Adam(learning_rate=lr_rate)
+        elif FLAGS.optimizer == 'rms-prop':
+            optimiser = tf.keras.optimizers.RMSprop(learning_rate=lr_rate)
+        elif FLAGS.optimizer == 'sgd':
+            optimiser = tf.keras.optimizers.SGD(learning_rate=lr_rate)
+        else:
+            logging.warning("Unrecognised optimizer '%s'; defaulting to Adam.", FLAGS.optimizer)
+            optimiser = tf.keras.optimizers.Adam(learning_rate=lr_rate)
+
+        # NOTE: loading checkpoints fails if the model has been built first,
+        # so only build (and summarise) when training from scratch
+        if FLAGS.train:
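+            # input shapes: 2D models take 288x288 slices (1 channel by default,
+            # 3 for pretrained RGB backbones); vnet takes 160x384x384 volumes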
+            if FLAGS.model_architecture != 'vnet':
+                if FLAGS.backbone_architecture == 'default':
+                    model.build((None, 288, 288, 1))
+                else:
+                    model.build((None, 288, 288, 3))
+            else:
+                model.build((None, 160, 384, 384, 1))
+            model.summary()
+
+        metrics = [dice_coef, iou_loss, crossentropy_loss_fn, 'acc']
+        if FLAGS.multi_class and FLAGS.use_2d:
+            metrics = [dice_coef, iou_loss, dice_coef_eval, iou_loss_eval,
+                       crossentropy_loss_fn, 'acc']
+
+        model.compile(optimizer=optimiser,
+                      loss=loss_fn,
+                      metrics=metrics)
+
+    if FLAGS.train:
+        # define checkpoints
+        time = datetime.now().strftime("%Y%m%d-%H%M%S")
+        training_history_dir = os.path.join(FLAGS.fig_dir, FLAGS.tpu)
+        training_history_dir = os.path.join(training_history_dir, time)
+        Path(training_history_dir).mkdir(parents=True, exist_ok=True)
+        flag_name = os.path.join(training_history_dir, 'test_flags.cfg')
+        FLAGS.append_flags_into_file(flag_name)
+
+        logdir = os.path.join(FLAGS.logdir, FLAGS.tpu)
+        logdir = os.path.join(logdir, time)
+        logdir_arch = os.path.join(logdir, FLAGS.model_architecture)
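+        # save_best_only=False writes a weights file every epoch; Keras expands
+        # {epoch:03d} to the zero-padded epoch number in the checkpoint name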
+        ckpt_cb = tf.keras.callbacks.ModelCheckpoint(logdir_arch + '_weights.{epoch:03d}.ckpt',
+                                                     save_best_only=False,
+                                                     save_weights_only=True)
+        tb = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch')
+
+        # reuse the schedule built for the optimizer so the custom training
+        # loop follows the same computed lr_decay_epochs
+        lr_manager = lr_rate
+
+        train = Train(epochs=FLAGS.train_epochs,
+                      batch_size=FLAGS.batch_size,
+                      enable_function=True,
+                      model=model,
+                      optimizer=optimiser,
+                      loss_func=loss_fn,
+                      lr_manager=lr_manager,
+                      predict_slice=FLAGS.which_slice,
+                      metrics=metrics,
+                      tfrec_dir=FLAGS.tfrec_dir,
+                      log_dir="logs")
+
+        log_dir_now = train.train_model_loop(train_ds=train_ds,
+                                             valid_ds=valid_ds,
+                                             strategy=strategy,
+                                             visual_save_freq=FLAGS.visual_save_freq,
+                                             multi_class=FLAGS.multi_class,
+                                             debug=False,
+                                             num_to_visualise=0)
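+        # train_model_loop presumably returns the run's log directory; the
+        # value is not used further here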
+
+    elif FLAGS.visual_file:
+        tpu = FLAGS.tpu_dir if FLAGS.tpu_dir else FLAGS.tpu
+
+        # the tb callback from the training branch does not exist on this path,
+        # so build a fresh TensorBoard callback here
+        time = datetime.now().strftime("%Y%m%d-%H%M%S")
+        tb = tf.keras.callbacks.TensorBoard(os.path.join(FLAGS.logdir, FLAGS.tpu, time),
+                                            update_freq='epoch')
+
+        eval_loop(dataset=valid_ds, validation_steps=validation_steps,
+                  aug_strategy=FLAGS.aug_strategy, bucket_name=FLAGS.bucket,
+                  logdir=FLAGS.logdir, tpu_name=tpu, visual_file=FLAGS.visual_file,
+                  weights_dir=FLAGS.weights_dir, fig_dir=FLAGS.fig_dir,
+                  which_volume=FLAGS.gif_volume, which_epoch=FLAGS.gif_epochs,
+                  which_slice=FLAGS.gif_slice, multi_as_binary=False,
+                  trained_model=model, model_architecture=FLAGS.model_architecture,
+                  callbacks=[tb], num_classes=num_classes)
+
+    else:
+        # no training and no visual_file: set up TensorBoard logging and leave
+        # a stub for confusion-matrix evaluation from FLAGS.weights_dir
+        time = datetime.now().strftime("%Y%m%d-%H%M%S")
+        logdir = os.path.join(FLAGS.logdir, FLAGS.tpu, time)
+        tb = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', write_images=True)
+        # TODO: re-enable confusion_matrix from Segmentation.utils.evaluation_utils
+        # confusion_matrix(trained_model=model,
+        #                  weights_dir=FLAGS.weights_dir,
+        #                  fig_dir=FLAGS.fig_dir,
+        #                  dataset=valid_ds,
+        #                  validation_steps=validation_steps,
+        #                  multi_class=FLAGS.multi_class,
+        #                  model_architecture=FLAGS.model_architecture,
+        #                  callbacks=[tb],
+        #                  num_classes=num_classes)
+
+
+if __name__ == '__main__':
+    app.run(main)