--- /dev/null
+++ b/Segmentation/utils/evaluation_utils.py
@@ -0,0 +1,1104 @@
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from matplotlib.animation import ArtistAnimation
+
+import glob
+from google.cloud import storage
+from pathlib import Path
+import os
+import shutil
+from datetime import datetime
+
+from Segmentation.utils.losses import dice_coef, iou_loss
+from Segmentation.plotting.voxels import plot_volume
+# from Segmentation.utils.data_loader import read_tfrecord_2d
+from Segmentation.utils.training_utils import visualise_binary, visualise_multi_class
+from Segmentation.utils.evaluation_metrics import get_confusion_matrix, plot_confusion_matrix, iou_loss_eval, dice_coef_eval
+
+def get_depth(conc):
+    """Total number of slices across a list of batches (sum over dim 0)."""
+    depth = 0
+    for batch in conc:
+        depth += batch.shape[0]
+    return depth
+
+
+def get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir):
+    """List the checkpoint weight files stored for this session in the GCS bucket."""
+
+    session_name = weights_dir.split('/')[3]
+    session_name = os.path.join(session_name, tpu_name, visual_file)
+    # Alternative (Pietro's): session_name = os.path.join(weights_dir, tpu_name, visual_file)
+
+    # Get the blob names within the session folder on gcloud
+    storage_client = storage.Client()
+    blobs = storage_client.list_blobs(bucket_name)
+    session_content = [blob.name for blob in blobs if session_name in blob.name]
+
+    # Keep only the checkpoint index files
+    session_weights = [item for item in session_content
+                       if ('_weights' in item) and ('.ckpt.index' in item)]
+
+    for s in session_weights:
+        print(s)  # print all the checkpoint paths found
+
+    return session_weights
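+
+# Example (a sketch; the bucket and session names below are placeholders,
+# and the bucket layout is assumed from the path handling above):
+#
+#     weights = get_all_weights('my-bucket', 'logs', 'my-tpu', 'run-01',
+#                               'a/b/c/session-name')
+#     # -> a list of '..._weights.ckpt.index' blob paths for that session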
+
+
+def pred_evolution_gif(fig,
+                       frames_list,
+                       interval=300,
+                       save_dir='',
+                       save=True,
+                       no_margins=True,
+                       show=False):
+    """Animate a list of matplotlib artist frames and save the result as a gif.
+
+    Requires an ffmpeg binary on the PATH: the animation is first written as
+    an mp4 and then converted to a gif with ffmpeg.
+    """
+
+    print("\n\n\n\n=================")
+    print("checking for ffmpeg...")
+    ffmpeg_path = shutil.which('ffmpeg')
+    if ffmpeg_path is None:
+        print("please install ffmpeg to create the gif")
+        print("gif not created")
+
+    else:
+        print("ffmpeg found")
+        print("creating the gif ...\n")
+        gif = ArtistAnimation(fig, frames_list, interval=interval, repeat=True)  # create the animation
+
+        if save:
+            if no_margins:  # strip the axes and whitespace so only the image remains
+                plt.tight_layout()
+                plt.gca().set_axis_off()
+                plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
+                                    hspace=0, wspace=0)
+                plt.margins(0, 0)
+                plt.gca().xaxis.set_major_locator(plt.NullLocator())
+                plt.gca().yaxis.set_major_locator(plt.NullLocator())
+
+            if save_dir == '':
+                time = datetime.now().strftime("%Y%m%d-%H%M%S")
+                save_dir = 'results/gif_' + time + '.gif'
+
+            plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path  # point matplotlib at the ffmpeg binary
+            Writer = animation.writers['ffmpeg']
+            ffmwriter = Writer(fps=1000 // interval, metadata=dict(artist='Me'), bitrate=1800)  # set the save writer
+            gif.save('results/temp_video.mp4', writer=ffmwriter)
+
+            code_bash = f"ffmpeg -i 'results/temp_video.mp4' -loop 0 '{save_dir}'"  # convert the mp4 to a gif
+            os.system(code_bash)
+            os.remove("results/temp_video.mp4")
+
+            plt.close('all')
+
+        if show:
+            plt.show()
+            plt.close('all')
+
+        print("\n\n=================")
+        print('done\n\n')
+
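+# Example usage (a minimal sketch; `my_preds` is a hypothetical (N, H, W)
+# array of predicted slices, not something defined in this module):
+#
+#     fig, ax = plt.subplots()
+#     frames = [[ax.imshow(p, cmap='gray', animated=True)] for p in my_preds]
+#     pred_evolution_gif(fig, frames, interval=300, save_dir='results/demo.gif')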
+
+########## Confusion Matrix ##########
+def initialize_cm(multi_class, num_classes=7):
+    """Return an empty confusion matrix and the class names for the knee cartilage labels."""
+    if multi_class:
+        cm = np.zeros((num_classes, num_classes))
+        classes = ["Background",
+                   "Femoral",
+                   "Medial Tibial",
+                   "Lateral Tibial",
+                   "Patellar",
+                   "Lateral Meniscus",
+                   "Medial Meniscus"]
+    else:
+        cm = np.zeros((2, 2))
+        classes = ["Background",
+                   "Cartilage"]
+
+    return cm, classes
+
+
+def update_cm(cm, label, pred, num_classes=7):
+    """Accumulate the confusion matrix with one batch of labels and predictions."""
+    cm = cm + get_confusion_matrix(label, pred, classes=list(range(0, num_classes)))
+    return cm
+
+
+def save_cm(cm, model_architecture, fig_dir, classes):
+    """Plot the accumulated confusion matrix and save it under fig_dir."""
+    fig_file = model_architecture + '_matrix.png'
+    fig_path = os.path.join(fig_dir, fig_file)
+    plot_confusion_matrix(cm, fig_path, classes=classes)
+##########
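+
+# Example usage (a sketch; `ds` is assumed to be a tf.data dataset of
+# (image, label) batches and `model` a compiled tf.keras model):
+#
+#     cm, classes = initialize_cm(multi_class=True, num_classes=7)
+#     for image, label in ds:
+#         pred = model.predict(image)
+#         cm = update_cm(cm, label, pred, num_classes=7)
+#     save_cm(cm, 'vnet', 'results', classes)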
+
+
+
+########## Gif ##########
+def initialize_gif():
+    """Create the figure, the three comparison axes and an empty frame list for a gif."""
+    fig, axes = plt.subplots(1, 3)
+    images_gif = []
+    return fig, axes, images_gif
+
+def update_gif_slice(x, y, trained_model,
+                     aug_strategy,
+                     multi_as_binary, multi_class,
+                     which_epoch, which_slice, which_volume=1,
+                     save_dir='',
+                     gif_cmap='gray',
+                     clean=False):
+    """Predict a single slice and save the input, label and prediction as pngs."""
+
+    x_slice = np.expand_dims(x[which_slice-1], axis=0)
+    y_slice = y[which_slice-1]
+
+    print('predicting slice {}'.format(which_slice))
+    pred_slice = trained_model.predict(x_slice)
+    print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape))
+    if multi_class:
+        pred_slice = np.argmax(pred_slice, axis=-1)
+        y_slice = np.argmax(y_slice, axis=-1)
+        if multi_as_binary:  # collapse all cartilage classes into a single foreground class
+            pred_slice[pred_slice > 0] = 1
+            y_slice[y_slice > 0] = 1
+    else:
+        pred_slice = np.squeeze(pred_slice, axis=-1)
+        y_slice = np.squeeze(y_slice, axis=-1)
+
+    print('slice predicted\n')
+    print('input image data type: {}, shape: {}'.format(type(x), x.shape))
+    print('label image data type: {}, shape: {}'.format(type(y), y.shape))
+    print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape))
+
+    print("Creating input image")
+    x_s = np.squeeze(x[which_slice-1], axis=-1)
+    fig_x = plt.figure()
+    ax_x = fig_x.add_subplot(1, 1, 1)
+    ax_x.imshow(x_s, cmap=gif_cmap)
+
+    print("Creating label image")
+    fig_y = plt.figure()
+    ax_y = fig_y.add_subplot(1, 1, 1)
+    ax_y.imshow(y_slice, cmap=gif_cmap)
+
+    print("Creating prediction image")
+    fig_pred = plt.figure()
+    ax_pred = fig_pred.add_subplot(1, 1, 1)
+    ax_pred.imshow(pred_slice[0], cmap=gif_cmap)
+
+    # Remove the outside frame
+    if clean:
+        ax_x.axis('off')
+        ax_y.axis('off')
+        ax_pred.axis('off')
+
+    print("Saving images")
+    fig_x.savefig(save_dir + '_x.png')
+    fig_y.savefig(save_dir + '_y.png')
+    fig_pred.savefig(save_dir + '_pred.png')
+
+
+def update_volume_comp_gif(x, y, axes, images_gif, trained_model,
+                           multi_class,
+                           which_epoch,
+                           which_volume=1,
+                           gif_dir='',
+                           gif_cmap='gray',
+                           clean=False):
+    """Append one frame per slice comparing input, label and prediction side by side."""
+
+    x = np.array(x)
+
+    print('predicting volume {}'.format(which_volume))
+    pred_vol = trained_model.predict(x)
+    if multi_class:
+        pred_vol = np.argmax(pred_vol, axis=-1)
+        y = np.argmax(y, axis=-1)
+    else:
+        pred_vol = np.squeeze(pred_vol, axis=-1)
+        y = np.squeeze(y, axis=-1)
+    print('volume predicted\n')
+
+    x = np.squeeze(x, axis=-1)  # drop the channel axis for display
+
+    print('input image data type: {}, shape: {}'.format(type(x), x.shape))
+    print('label image data type: {}, shape: {}'.format(type(y), y.shape))
+    print('prediction image data type: {}, shape: {}\n'.format(type(pred_vol), pred_vol.shape))
+
+    for i in range(x.shape[0]):
+        print(f"Analysing slice {i+1}")
+        x_im = axes[0].imshow(x[i, :, :], cmap=gif_cmap, animated=True, aspect='auto')
+        y_im = axes[1].imshow(y[i, :, :], cmap=gif_cmap, animated=True, aspect='auto')
+        pred_im = axes[2].imshow(pred_vol[i, :, :], cmap=gif_cmap, animated=True, aspect='auto')
+        if not clean:
+            text = axes[1].text(0.5, 1.05, f'Slice {i+1}',
+                                size=plt.rcParams["axes.titlesize"],
+                                ha="center", transform=axes[1].transAxes)
+            images_gif.append([x_im, y_im, pred_im, text])
+        else:
+            for ax in axes:
+                ax.axis('off')
+            images_gif.append([x_im, y_im, pred_im])
+
+    return images_gif
+
+
+def update_epoch_gif(x, ax, images_gif, epoch, trained_model, aug_strategy,
+                     multi_class, which_slice, which_volume=1,
+                     epoch_limit=1000,
+                     gif_dir='',
+                     gif_cmap='gray',
+                     clean=False):
+    """Append the prediction for one slice at the current epoch to the gif frame list."""
+
+    x_slice = np.expand_dims(x[which_slice-1], axis=0)
+    print('Input image data type: {}, shape: {}\n'.format(type(x_slice), x_slice.shape))
+
+    print('predicting slice {}'.format(which_slice))
+    predicted_slice = trained_model.predict(x_slice)
+    if multi_class:
+        predicted_slice = np.argmax(predicted_slice, axis=-1)
+    else:
+        predicted_slice = np.squeeze(predicted_slice, axis=-1)
+
+    print('slice predicted\n')
+
+    print("adding prediction to the queue")
+    im = ax.imshow(predicted_slice[0], cmap=gif_cmap, animated=True)
+    if not clean:
+        text = ax.text(0.5, 1.05, f"Epoch {epoch}",
+                       size=plt.rcParams["axes.titlesize"],
+                       ha="center", transform=ax.transAxes)
+        images_gif.append([im, text])
+    else:
+        ax.axis('off')
+        images_gif.append([im])
+    print("prediction added\n")
+
+    return images_gif
+##########
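+
+# Example (a sketch of how the gif helpers fit together; `model`, `ds` and
+# `which_volume` are assumed to be provided by the caller):
+#
+#     fig, axes, images_gif = initialize_gif()
+#     for step, (x, y) in enumerate(ds):
+#         if step + 1 == which_volume:
+#             images_gif = update_volume_comp_gif(x, y, axes, images_gif, model,
+#                                                 multi_class=True, which_epoch=10)
+#             break
+#     pred_evolution_gif(fig, images_gif, save_dir='results/volume.gif')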
+
+
+
+########## Plotly npys ##########
+def update_volume_npy(y, pred, target,
+                      sample_pred, sample_y,
+                      visual_file, name,
+                      which_volume, multi_class):
+    """Accumulate batches of slices until a full volume of `target` slices is
+    collected, then crop the centre of the volume and save the prediction and
+    ground truth as .npy files under results/pred and results/ground_truth.
+
+    Once the numpys are generated, run the following in a console to get an
+    embeddable html file:
+        python3 Visualization/plotly_3d_voxels/run_plotly.py -dir_l FOLDER_TO_Y_SAMPLES -dir_r FOLDER_TO_PREDICTIONS
+    """
+    batch_size = y.shape[0]
+    y = np.array(y)
+    pred = np.array(pred)
+
+    if (get_depth(sample_pred) + batch_size) < target:  # the next batch still fits in the volume
+        sample_pred.append(pred)
+        sample_y.append(y)
+    else:
+        remaining = target - get_depth(sample_pred)
+        sample_pred.append(pred[:remaining])
+        sample_y.append(y[:remaining])
+        pred_vol = np.concatenate(sample_pred)
+        y_vol = np.concatenate(sample_y)
+        # the leftover slices start the next volume
+        sample_pred = [pred[remaining:]]
+        sample_y = [y[remaining:]]
+
+        if multi_class:
+            pred_vol = np.argmax(pred_vol, axis=-1)
+            y_vol = np.argmax(y_vol, axis=-1)
+
+        # Save the volumes as numpy files for plotly
+        fig_dir = "results"
+        name_pred_npy = os.path.join(fig_dir, "pred", (visual_file + "_" + name))
+        name_y_npy = os.path.join(fig_dir, "ground_truth", (visual_file + "_vol" + str(which_volume).zfill(3)))
+
+        # Keep only the central 80x80x80 region: the full 288x288x160 volume is too large to plot
+        roi = int(80 / 2)
+        d1, d2, d3 = np.shape(pred_vol)[0:3]
+        d1, d2, d3 = int(np.floor(d1 / 2)), int(np.floor(d2 / 2)), int(np.floor(d3 / 2))
+        pred_vol = pred_vol[(d1 - roi):(d1 + roi), (d2 - roi):(d2 + roi), (d3 - roi):(d3 + roi)]
+        d1, d2, d3 = np.shape(y_vol)[0:3]
+        d1, d2, d3 = int(np.floor(d1 / 2)), int(np.floor(d2 / 2)), int(np.floor(d3 / 2))
+        y_vol = y_vol[(d1 - roi):(d1 + roi), (d2 - roi):(d2 + roi), (d3 - roi):(d3 + roi)]
+
+        print('y_vol.shape', np.shape(y_vol))
+
+        np.save(name_pred_npy, pred_vol)
+        np.save(name_y_npy, y_vol)
+        print("Total voxels saved, pred:", np.sum(pred_vol), "y:", np.sum(y_vol))
+
+    return sample_pred, sample_y
+
+##########
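+
+# Example (a sketch; `ds` is assumed to yield (x, y) batches of 2D slices and
+# `model` to be a trained tf.keras model — neither is defined in this module):
+#
+#     sample_pred, sample_y = [], []
+#     for x, y in ds:
+#         pred = model.predict(x)
+#         sample_pred, sample_y = update_volume_npy(y, pred, 160,
+#                                                   sample_pred, sample_y,
+#                                                   'my_run', 'weights.001',
+#                                                   which_volume=1, multi_class=True)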
+
+
+def eval_loop(dataset, validation_steps, aug_strategy,
+              bucket_name, logdir, tpu_name, visual_file, weights_dir,
+              fig_dir,
+              which_volume, which_epoch, which_slice,
+              multi_as_binary,
+              trained_model, model_architecture,
+              callbacks,
+              num_classes=7):
+    """Evaluate the model at every saved checkpoint and generate the visualisations."""
+
+    multi_class = num_classes > 1
+    gif_dir = ''
+
+    # load the checkpoints in the specified log directory
+    session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir)
+
+    # Summary writer for the evaluation metrics (as in the original confusion matrix function)
+    writer_dir = os.path.join(os.path.dirname(weights_dir), 'eval')
+    eval_metric_writer = tf.summary.create_file_writer(writer_dir)
+
+    # Initialise the visuals
+    cm, classes = initialize_cm(multi_class, num_classes)
+    fig, axes, images_gif = initialize_gif()
+    target = 160  # number of slices in one volume
+
+    for chkpt in session_weights:
+        # Note: skip straight to the last checkpoint if you only want the evaluation
+        name = chkpt.split('/')[-1]
+        name = name.split('.inde')[0]
+        epoch = name.split('.')[1]
+
+        print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+        print(f"\t\tLoading weights from epoch {epoch}")
+        print(f"\t\t  {name}")
+        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")
+
+        trained_model.load_weights('gs://' + os.path.join(bucket_name,
+                                                          weights_dir,
+                                                          tpu_name,
+                                                          visual_file,
+                                                          name)).expect_partial()
+        if chkpt == session_weights[-1]:  # only run the full evaluation on the final checkpoint
+            trained_model.evaluate(dataset, steps=validation_steps, callbacks=callbacks)
+
+        # Reset the volume accumulators for this checkpoint
+        sample_pred = []  # predictions for the current 160x288x288 volume
+        sample_y = []     # ground truth for the current 160x288x288 volume
+
+        for step, (x, label) in enumerate(dataset):
+            print('step', step)
+            pred = trained_model.predict(x)
+
+            # Update the visuals
+            cm = update_cm(cm, label, pred, num_classes)
+            visualise_multi_class(label, pred)
+
+            if step + 1 == which_volume:
+                update_gif_slice(x, label, trained_model,
+                                 aug_strategy,
+                                 multi_as_binary, multi_class,
+                                 which_epoch, which_slice)
+
+                images_gif = update_volume_comp_gif(x, label, axes, images_gif, trained_model,
+                                                    multi_class,
+                                                    which_epoch,
+                                                    gif_dir=gif_dir)
+
+                images_gif = update_epoch_gif(x, axes[0], images_gif, epoch,
+                                              trained_model, aug_strategy,
+                                              multi_class, which_slice,
+                                              gif_dir=gif_dir)
+
+                sample_pred, sample_y = update_volume_npy(label, pred, target,
+                                                          sample_pred, sample_y,
+                                                          visual_file, name,
+                                                          which_volume, multi_class)
+
+            iou = iou_loss_eval(label, pred) if multi_class else iou_loss(label, pred)
+            dice = dice_coef_eval(label, pred) if multi_class else dice_coef(label, pred)
+
+            with eval_metric_writer.as_default():
+                tf.summary.scalar('iou eval validation', iou, step=step)
+                tf.summary.scalar('dice eval validation', dice, step=step)
+
+            if step > validation_steps - 1:  # stop after one pass over the validation set
+                break
+
+        # Save the visuals
+        save_cm(cm, model_architecture, fig_dir, classes)
+    pred_evolution_gif(fig, images_gif, save_dir=gif_dir, save=True, no_margins=False)
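+
+# Example invocation (a sketch; the dataset, model and GCS paths are
+# placeholders for whatever the training script provides):
+#
+#     eval_loop(dataset=valid_ds, validation_steps=96, aug_strategy=None,
+#               bucket_name='my-bucket', logdir='logs', tpu_name='my-tpu',
+#               visual_file='run-01', weights_dir='a/b/c/session-name',
+#               fig_dir='results',
+#               which_volume=2, which_epoch=10, which_slice=80,
+#               multi_as_binary=False,
+#               trained_model=model, model_architecture='vnet',
+#               callbacks=[], num_classes=7)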
+