--- a +++ b/Segmentation/utils/evaluation_utils.py @@ -0,0 +1,1104 @@ +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from matplotlib.animation import ArtistAnimation + +import glob +from google.cloud import storage +from pathlib import Path +import os +import datetime + +from Segmentation.utils.losses import dice_coef +from Segmentation.plotting.voxels import plot_volume +# from Segmentation.utils.data_loader import read_tfrecord_2d +from Segmentation.utils.training_utils import visualise_binary, visualise_multi_class +from Segmentation.utils.evaluation_metrics import get_confusion_matrix, plot_confusion_matrix, iou_loss_eval, dice_coef_eval +from Segmentation.utils.losses import dice_coef, iou_loss + +def get_depth(conc): + depth = 0 + for batch in conc: + depth += batch.shape[0] + return depth + + +def get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir): + """ Load the checkpoints in the specified log directory """ + + ###################### + """ Add the visualisation code here """ + print("+========================================================") + print('bucket_name',bucket_name) + print("\n\nThe directories are:") + print('weights_dir == "checkpoint"',weights_dir == "checkpoint") + print('weights_dir',weights_dir) + ###################### + + session_name = weights_dir.split('/')[3] + session_name = os.path.join(session_name, tpu_name, visual_file) + # Pietro's: session_name = os.path.join(weights_dir, tpu_name, visual_file) + + # Get names within folder in gcloud + storage_client = storage.Client() + blobs = storage_client.list_blobs(bucket_name) + session_content = [] + print('session_name',session_name) + for blob in blobs: + if session_name in blob.name: + session_content.append(blob.name) + + session_weights = [] + for item in session_content: + if ('_weights' in item) and ('.ckpt.index' in item): + session_weights.append(item) + + ###################### + for s in session_weights: + print(s) #print all the checkpoint directories + print("--") + ###################### + + return session_weights + +# def plot_and_eval_3D(model, + # logdir, + # visual_file, + # tpu_name, + # bucket_name, + # weights_dir, + # multi_class, + # save_freq, + # dataset, + # model_args): + + # """ plotly: Generates a numpy volume for every #save_freq number of weights + # and saves it in local results/pred/*visual_file* and results/y/*visual_file* + + # Once numpy's are generated, run the following in console to get an embeddable html file: + # python3 Visualization/plotly_3d_voxels/run_plotly.py -dir_l FOLDER_TO_Y_SAMPLES + # -dir_r FOLDER_TO_PREDICTIONS + + + # """ + + # session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + + # # Only use part of dataset + # idx_vol= 0 # how many numpies have been save + # target = 160 + + # for i, chkpt in enumerate(session_weights): + + # should_save_np = np.mod((i+1), save_freq) == 0 + + # ###################### + # # print('should_save_np',should_save_np) + # # print('checkpoint enum i',i) + # # print('save_freq set to ',save_freq) + # ###################### + + # if not should_save_np: # skip this checkpoint weight + # print("Skipping weight", chkpt) + # continue + + + # name = chkpt.split('/')[-1] + # name = name.split('.inde')[0] + # trained_model = model(*model_args) + # trained_model.load_weights('gs://' + os.path.join(bucket_name, + # 'checkpoints', + # tpu_name, + # visual_file, + # name)).expect_partial() + + + + # # sample_x = [] # x for current 160,288,288 vol + # sample_pred = [] # prediction for current 160,288,288 vol + # sample_y = [] # y for current 160,288,288 vol + + # which_volume = 2 + # for idx, ds in enumerate(dataset): + + # ###################### + # print('Current chkpt name',name) + # print(f"the index is {idx}") + # ###################### + + + # x, y = ds + # batch_size = x.shape[0] + + # if batch_size == 160: + # if not (int(idx) == int(which_volume)): + # continue + + # x = np.array(x) + # y = np.array(y) + + # pred = trained_model.predict(x) + + # ###################### + # # print("Current batch size set to {}. Target depth is {}".format(batch_size, target)) + # # print('Input image data type: {}, shape: {}'.format(type(x), x.shape)) + # # print('Ground truth data type: {}, shape: {}'.format(type(y), y.shape)) + # # print('Prediction data type: {}, shape: {}'.format(type(pred), pred.shape)) + # # print("=================") + # ###################### + + # if (get_depth(sample_pred) + batch_size) < target: # check if next batch will fit in volume (160) + # sample_pred.append(pred) + # del pred + # sample_y.append(y) + # del y + # else: + # remaining = target - get_depth(sample_pred) + # sample_pred.append(pred[:remaining]) + # sample_y.append(y[:remaining]) + # pred_vol = np.concatenate(sample_pred) + # del sample_pred + # y_vol = np.concatenate(sample_y) + # del sample_y + # sample_pred = [pred[remaining:]] + # sample_y = [y[remaining:]] + + # del pred + # del y + + # ###################### + # # print("===============") + # # print("pred done") + # # print(pred_vol.shape) + # # print(y_vol.shape) + # # print("===============") + # # print('multi_class', multi_class) + # ###################### + + # if multi_class: # or np.shape(pred_vol)[-1] not + # pred_vol = np.argmax(pred_vol, axis=-1) + # y_vol = np.argmax(y_vol, axis=-1) + + # ###################### + # # print('np.shape(pred_vol)', np.shape(pred_vol)) + # # print('np.shape(y_vol)',np.shape(y_vol)) + # ###################### + + # # Save volume as numpy file for plotlyyy + # fig_dir = "results" + # name_pred_npy = os.path.join(fig_dir, "pred", (visual_file + "_" + name)) + # name_y_npy = os.path.join(fig_dir, "ground_truth", (visual_file + "_" + str(which_volume).zfill(3))) + + # ###################### + # # print("npy save pred as ", name_pred_npy) + # # print("npy save y as ", name_y_npy) + # # print("Currently on vol ", idx_vol) + # ###################### + + + # # Get middle xx slices cuz 288x288x160 too big + # roi = int(80 / 2) + # d1,d2,d3 = np.shape(pred_vol)[0:3] + # d1, d2, d3 = int(np.floor(d1/2)), int(np.floor(d2/2)), int(np.floor(d3/2)) + # pred_vol = pred_vol[(d1-roi):(d1+roi),(d2-roi):(d2+roi), (d3-roi):(d3+roi)] + # d1,d2,d3 = np.shape(y_vol)[0:3] + # d1, d2, d3 = int(np.floor(d1/2)), int(np.floor(d2/2)), int(np.floor(d3/2)) + # y_vol = y_vol[(d1-roi):(d1+roi),(d2-roi):(d2+roi), (d3-roi):(d3+roi)] + + # ###################### + # print('y_vol.shape', np.shape(y_vol)) + # ###################### + + # np.save(name_pred_npy,pred_vol) + # np.save(name_y_npy,y_vol) + # idx_vol += 1 + # ###################### + # print("Total voxels saved, pred:", np.sum(pred_vol), "y:", np.sum(y_vol)) + # ###################### + # del pred_vol + # del y_vol + + # break + +# def epoch_gif(model, + # logdir, + # tfrecords_dir, + # aug_strategy, + # visual_file, + # tpu_name, + # bucket_name, + # weights_dir, + # multi_class, + # model_args, + # which_slice, + # which_volume=1, + # epoch_limit=1000, + # gif_dir='', + # gif_cmap='gray', + # clean=False): + + # #load the database + # valid_ds = read_tfrecord_2d(tfrecords_dir=tfrecords_dir, #'gs://oai-challenge-dataset/tfrecords/valid/', + # batch_size=160, + # buffer_size=500, + # augmentation=aug_strategy, + # multi_class=multi_class, + # is_training=False, + # use_bfloat16=False, + # use_RGB=False) + + # # load the checkpoints in the specified log directory + # session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + + # #figure for gif + # fig, ax = plt.subplots() + # images_gif = [] + + # for chkpt in session_weights: + # name = chkpt.split('/')[-1] + # name = name.split('.inde')[0] + + # if int(name.split('.')[1]) <= epoch_limit: + + # print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + # print(f"\t\tLoading weights from {name.split('.')[1]} epoch") + # print(f"\t\t {name}") + # print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + + # trained_model = model(*model_args) + # trained_model.load_weights('gs://' + os.path.join(bucket_name, + # weights_dir, + # tpu_name, + # visual_file, + # name)).expect_partial() + + # for idx, ds in enumerate(valid_ds): + + # if idx+1 == which_volume: + # x, _ = ds + # x_slice = np.expand_dims(x[which_slice-1], axis=0) + # print('Input image data type: {}, shape: {}\n'.format(type(x_slice), x_slice.shape)) + + # print('predicting slice {}'.format(which_slice)) + # predicted_slice = trained_model.predict(x_slice) + # if multi_class: + # predicted_slice = np.argmax(predicted_slice, axis=-1) + # else: + # predicted_slice = np.squeeze(predicted_slice, axis=-1) + + # print('slice predicted\n') + + # print("adding prediction to the queue") + # im = ax.imshow(predicted_slice[0], cmap=gif_cmap, animated=True) + # if not clean: + # text = ax.text(0.5,1.05,f"Epoch {int(name.split('.')[1])}", + # size=plt.rcParams["axes.titlesize"], + # ha="center", transform=ax.transAxes) + # images_gif.append([im, text]) + # else: + # ax.axis('off') + # images_gif.append([im]) + # print("prediction added\n") + + # break + + # else: + # break + + # pred_evolution_gif(fig, images_gif, save_dir=gif_dir, save=True, no_margins=clean) + +# def volume_gif(model, + # logdir, + # tfrecords_dir, + # aug_strategy, + # visual_file, + # tpu_name, + # bucket_name, + # weights_dir, + # multi_class, + # model_args, + # which_epoch, + # which_volume=1, + # gif_dir='', + # gif_cmap='gray', + # clean=False): + + # #load the database + # valid_ds = read_tfrecord_2d(tfrecords_dir=tfrecords_dir, #'gs://oai-challenge-dataset/tfrecords/valid/', + # batch_size=160, + # buffer_size=500, + # augmentation=aug_strategy, + # multi_class=multi_class, + # is_training=False, + # use_bfloat16=False, + # use_RGB=False) + + # # load the checkpoints in the specified log directory + # session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + + # #figure for gif + # fig, ax = plt.subplots() + # images_gif = [] + + # for chkpt in session_weights: + # name = chkpt.split('/')[-1] + # name = name.split('.inde')[0] + + # if int(name.split('.')[1]) == which_epoch: + + # print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + # print(f"\t\tLoading weights from {name.split('.')[1]} epoch") + # print(f"\t\t {name}") + # print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + + # trained_model = model(*model_args) + # trained_model.load_weights('gs://' + os.path.join(bucket_name, + # weights_dir, + # tpu_name, + # visual_file, + # name)).expect_partial() + + # for idx, ds in enumerate(valid_ds): + + # if idx+1 == which_volume: + # x, _ = ds + # x = np.array(x) + # print('Input image data type: {}, shape: {}\n'.format(type(x), x.shape)) + + # print('predicting volume {}'.format(which_volume)) + # pred_vol = trained_model.predict(x) + # if multi_class: + # pred_vol = np.argmax(pred_vol, axis=-1) + # else: + # pred_vol = np.squeeze(pred_vol, axis=-1) + # print('volume predicted\n') + + # for i in range(x.shape[0]): + # print(f"Analysing slice {i+1}") + # im = ax.imshow(pred_vol[i,:,:], cmap='gray', animated=True, aspect='auto') + # if not clean: + # text = ax.text(0.5,1.05,f'Slice {i+1}', + # size=plt.rcParams["axes.titlesize"], + # ha="center", transform=ax.transAxes) + # images_gif.append([im, text]) + # else: + # ax.axis('off') + # images_gif.append([im]) + + # break + + # break + + # pred_evolution_gif(fig, images_gif, save_dir=gif_dir, save=True, no_margins=clean) + +# def volume_comparison_gif(model, + # logdir, + # tfrecords_dir, + # visual_file, + # tpu_name, + # bucket_name, + # weights_dir, + # multi_class, + # model_args, + # which_epoch, + # which_volume=1, + # gif_dir='', + # gif_cmap='gray', + # clean=False): + + # #load the database + # valid_ds = read_tfrecord_2d(tfrecords_dir=tfrecords_dir, #'gs://oai-challenge-dataset/tfrecords/valid/', + # batch_size=160, + # buffer_size=500, + # augmentation=None, + # multi_class=multi_class, + # is_training=False, + # use_bfloat16=False, + # use_RGB=False) + + # # load the checkpoints in the specified log directory + # session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + + # #figure for gif + # fig, axes = plt.subplots(1, 3) + # images_gif = [] + + # for chkpt in session_weights: + # name = chkpt.split('/')[-1] + # name = name.split('.inde')[0] + + # if int(name.split('.')[1]) == which_epoch: + + # print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + # print(f"\t\tLoading weights from {name.split('.')[1]} epoch") + # print(f"\t\t {name}") + # print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + + # trained_model = model(*model_args) + # trained_model.load_weights('gs://' + os.path.join(bucket_name, + # weights_dir, + # tpu_name, + # visual_file, + # name)).expect_partial() + + # for idx, ds in enumerate(valid_ds): + + # if idx+1 == which_volume: + # x, y = ds + # x = np.array(x) + # x = np.squeeze(x, axis=-1) + + # print('predicting volume {}'.format(which_volume)) + # pred_vol = trained_model.predict(x) + # if multi_class: + # pred_vol = np.argmax(pred_vol, axis=-1) + # y = np.argmax(y, axis=-1) + # print('volume predicted\n') + + # print('input image data type: {}, shape: {}'.format(type(x), x.shape)) + # print('label image data type: {}, shape: {}'.format(type(y), y.shape)) + # print('prediction image data type: {}, shape: {}\n'.format(type(pred), pred.shape)) + + # for i in range(x.shape[0]): + # print(f"Analysing slice {i+1}") + # x_im = axes[0].imshow(x[i,:,:], cmap='gray', animated=True, aspect='auto') + # y_im = axes[1].imshow(y[i,:,:], cmap='gray', animated=True, aspect='auto') + # pred_im = axes[2].imshow(pred_vol[i,:,:], cmap='gray', animated=True, aspect='auto') + # if not clean: + # text = ax.text(0.5,1.05,f'Slice {i+1}', + # size=plt.rcParams["axes.titlesize"], + # ha="center", transform=ax.transAxes) + # images_gif.append([im, text]) + # else: + # ax.axis('off') + # images_gif.append([im]) + + # break + + # break + + # pred_evolution_gif(fig, images_gif, save_dir=gif_dir, save=True, no_margins=False) + + +def pred_evolution_gif(fig, + frames_list, + interval=300, + save_dir='', + save=True, + no_margins=True, + show=False): + + print("\n\n\n\n=================") + print("checking for ffmpeg...") + if not os.path.isfile('./../../../opt/conda/bin/ffmpeg'): + print("please 'pip install ffmpeg' to create gif") + print("gif not created") + + else: + print("ffmpeg found") + print("creating the gif ...\n") + gif = ArtistAnimation(fig, frames_list, interval, repeat=True) # create gif + + if save: + if no_margins: + plt.tight_layout() + plt.gca().set_axis_off() + plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, + hspace = 0, wspace = 0) + plt.margins(0,0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + + if save_dir == '': + time = datetime.now().strftime("%Y%m%d-%H%M%S") + save_dir = 'results/gif'+ time + '.gif' + + plt.rcParams['animation.ffmpeg_path'] = r'//opt//conda//bin//ffmpeg' # set directory of ffmpeg binary file + Writer = animation.writers['ffmpeg'] + ffmwriter = Writer(fps=1000//interval, metadata=dict(artist='Me'), bitrate=1800) #set the save writer + gif.save('results/temp_video.mp4', writer=ffmwriter) + + codeBASH = f"ffmpeg -i 'results/temp_video.mp4' -loop 0 {save_dir}" #convert mp4 to gif + os.system(codeBASH) + os.remove("results/temp_video.mp4") + + plt.close('all') + + if show: + plt.show() + plt.close('all') + + print("\n\n=================") + print('done\n\n') + +# def take_slice(model, + # logdir, + # tfrecords_dir, + # aug_strategy, + # visual_file, + # tpu_name, + # bucket_name, + # weights_dir, + # multi_as_binary, + # multi_class, + # model_args, + # which_epoch, + # which_slice, + # which_volume=1, + # save_dir='', + # cmap='gray', + # clean=False): + + # #load the database + # valid_ds = read_tfrecord_2d(tfrecords_dir=tfrecords_dir, #'gs://oai-challenge-dataset/tfrecords/valid/', + # batch_size=160, + # buffer_size=500, + # augmentation=aug_strategy, + # multi_class=multi_class, + # is_training=False, + # use_bfloat16=False, + # use_RGB=False) + + # # load the checkpoints in the specified log directory + # session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + + # #figure for gif + # fig, axes = plt.subplots(1, 3) + # images_gif = [] + + # for chkpt in session_weights: + # name = chkpt.split('/')[-1] + # name = name.split('.inde')[0] + + # if int(name.split('.')[1]) == which_epoch: + + # print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + # print(f"\t\tLoading weights from {name.split('.')[1]} epoch") + # print(f"\t\t {name}") + # print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + + # trained_model = model(*model_args) + # trained_model.load_weights('gs://' + os.path.join(bucket_name, + # weights_dir, + # tpu_name, + # visual_file, + # name)).expect_partial() + + # for idx, ds in enumerate(valid_ds): + + # if idx+1 == which_volume: + # x, y = ds + # x_slice = np.expand_dims(x[which_slice-1], axis=0) + # y_slice = y[which_slice-1] + + # print('predicting slice {}'.format(which_slice)) + # pred_slice = trained_model.predict(x_slice) + # print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape)) + # if multi_class: + # pred_slice = np.argmax(pred_slice, axis=-1) + # y_slice = np.argmax(y_slice, axis=-1) + # if multi_as_binary: + # pred_slice[pred_slice>0] = 1 + # y_slice[y_slice>0] = 1 + # else: + # pred_slice = np.squeeze(pred_slice, axis=-1) + # y_slice = np.squeeze(y_slice, axis=-1) + # print('slice predicted\n') + + # print('input image data type: {}, shape: {}'.format(type(x), x.shape)) + # print('label image data type: {}, shape: {}'.format(type(y), y.shape)) + # print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape)) + + # print("Creating input image") + # x_s = np.squeeze(x[which_slice-1], axis=-1) + # fig_x = plt.figure() + # ax_x = fig_x.add_subplot(1, 1, 1) + # ax_x.imshow(x_s, cmap='gray') + + # print("Creating label image") + # fig_y = plt.figure() + # ax_y = fig_y.add_subplot(1, 1, 1) + # ax_y.imshow(y_slice, cmap='gray') + + # print("Creating prediction image") + # fig_pred = plt.figure() + # ax_pred = fig_pred.add_subplot(1, 1, 1) + # ax_pred.imshow(pred_slice[0], cmap='gray') + + # #Removing outside frame + # if clean: + # ax_x.axis('off') + # ax_y.axis('off') + # ax_pred.axis('off') + + + # print("Saving images") + # save_dir_x = save_dir + '_x.png' + # save_dir_y = save_dir + '_y.png' + # save_dir_pred = save_dir + '_pred.png' + # fig_x.savefig(save_dir_x) + # fig_y.savefig(save_dir_y) + # fig_pred.savefig(save_dir_pred) + + # break + + # break + +# def confusion_matrix(trained_model, + # weights_dir, + # fig_dir, + # dataset, + # validation_steps, + # multi_class, + # model_architecture, + # callbacks, + # num_classes=7 + # ): + + # trained_model.load_weights(weights_dir).expect_partial() + # trained_model.evaluate(dataset, steps=validation_steps, callbacks=callbacks) + + + # f = weights_dir.split('/')[-1] + # # Excluding parenthese before f too + # if weights_dir.endswith(f): + # writer_dir = weights_dir[:-(len(f)+1)] + # writer_dir = os.path.join(writer_dir, 'eval') + # # os.makedirs(writer_dir) + # eval_metric_writer = tf.summary.create_file_writer(writer_dir) + + + # if multi_class: + # cm = np.zeros((num_classes, num_classes)) + # classes = ["Background", + # "Femoral", + # "Medial Tibial", + # "Lateral Tibial", + # "Patellar", + # "Lateral Meniscus", + # "Medial Meniscus"] + # else: + # cm = np.zeros((2, 2)) + # classes = ["Background", + # "Cartilage"] + + # for step, (image, label) in enumerate(dataset): + # print(step) + # pred = trained_model.predict(image) + # visualise_multi_class(label, pred) + # cm = cm + get_confusion_matrix(label, pred, classes=list(range(0, num_classes))) + + # if multi_class: + # iou = iou_loss_eval(label, pred) + # dice = dice_coef_eval(label, pred) + # else: + # iou = iou_loss(label, pred) + # dice = dice_coef(label, pred) + + # with eval_metric_writer.as_default(): + # tf.summary.scalar('iou eval validation', iou, step=step) + # tf.summary.scalar('dice eval validation', dice, step=step) + + + + + # if step > validation_steps - 1: + # break + + # fig_file = model_architecture + '_matrix.png' + # fig_dir = os.path.join(fig_dir, fig_file) + # plot_confusion_matrix(cm, fig_dir, classes=classes) + + + + +########## Confusion Matrix ########## +def initialize_cm(multi_class, num_classes=7): + if multi_class: + cm = np.zeros((num_classes, num_classes)) + classes = ["Background", + "Femoral", + "Medial Tibial", + "Lateral Tibial", + "Patellar", + "Lateral Meniscus", + "Medial Meniscus"] + else: + cm = np.zeros((2, 2)) + classes = ["Background", + "Cartilage"] + + return cm, classes + + +def update_cm(cm, num_classes=7): + cm = cm + get_confusion_matrix(label, pred, classes=list(range(0, num_classes))) + return cm + +def save_cm(cm, model_architecture, fig_dir, classes): + fig_file = model_architecture + '_matrix.png' + fig_dir = os.path.join(fig_dir, fig_file) + plot_confusion_matrix(cm, fig_dir, classes=classes) +########## + + + +########## Gif ########## +def initialize_gif(): + #figure for gif + fig, axes = plt.subplots(1, 3) + images_gif = [] + return fig, axes, images_gif + +def update_gif_slice(x, y, trained_model, + aug_strategy, + multi_as_binary, multi_class, + which_epoch, which_slice, which_volume=1, + save_dir='', + gif_cmap='gray', + clean=False): + + + x_slice = np.expand_dims(x[which_slice-1], axis=0) + y_slice = y[which_slice-1] + + print('predicting slice {}'.format(which_slice)) + pred_slice = trained_model.predict(x_slice) + print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape)) + if multi_class: + pred_slice = np.argmax(pred_slice, axis=-1) + y_slice = np.argmax(y_slice, axis=-1) + if multi_as_binary: + pred_slice[pred_slice>0] = 1 + y_slice[y_slice>0] = 1 + else: + pred_slice = np.squeeze(pred_slice, axis=-1) + y_slice = np.squeeze(y_slice, axis=-1) + + ############### + print('slice predicted\n') + print('input image data type: {}, shape: {}'.format(type(x), x.shape)) + print('label image data type: {}, shape: {}'.format(type(y), y.shape)) + print('prediction image data type: {}, shape: {}\n'.format(type(pred_slice), pred_slice.shape)) + ############### + + print("Creating input image") + x_s = np.squeeze(x[which_slice-1], axis=-1) + fig_x = plt.figure() + ax_x = fig_x.add_subplot(1, 1, 1) + ax_x.imshow(x_s, cmap='gray') + + print("Creating label image") + fig_y = plt.figure() + ax_y = fig_y.add_subplot(1, 1, 1) + ax_y.imshow(y_slice, cmap='gray') + + print("Creating prediction image") + fig_pred = plt.figure() + ax_pred = fig_pred.add_subplot(1, 1, 1) + ax_pred.imshow(pred_slice[0], cmap='gray') + + #Removing outside frame + if clean: + ax_x.axis('off') + ax_y.axis('off') + ax_pred.axis('off') + + + print("Saving images") + save_dir_x = save_dir + '_x.png' + save_dir_y = save_dir + '_y.png' + save_dir_pred = save_dir + '_pred.png' + fig_x.savefig(save_dir_x) + fig_y.savefig(save_dir_y) + fig_pred.savefig(save_dir_pred) + + +def update_volume_comp_gif(x,y, images_gif, trained_model, + multi_class, + which_epoch, + which_volume=1, + gif_dir='', + gif_cmap='gray', + clean=False): + + x = np.array(x) + x = np.squeeze(x, axis=-1) + + print('predicting volume {}'.format(which_volume)) + pred_vol = trained_model.predict(x) + if multi_class: + pred_vol = np.argmax(pred_vol, axis=-1) + y = np.argmax(y, axis=-1) + print('volume predicted\n') + + print('input image data type: {}, shape: {}'.format(type(x), x.shape)) + print('label image data type: {}, shape: {}'.format(type(y), y.shape)) + print('prediction image data type: {}, shape: {}\n'.format(type(pred), pred.shape)) + + for i in range(x.shape[0]): + print(f"Analysing slice {i+1}") + x_im = axes[0].imshow(x[i,:,:], cmap='gray', animated=True, aspect='auto') + y_im = axes[1].imshow(y[i,:,:], cmap='gray', animated=True, aspect='auto') + pred_im = axes[2].imshow(pred_vol[i,:,:], cmap='gray', animated=True, aspect='auto') + if not clean: + text = ax.text(0.5,1.05,f'Slice {i+1}', + size=plt.rcParams["axes.titlesize"], + ha="center", transform=ax.transAxes) + images_gif.append([im, text]) + else: + ax.axis('off') + images_gif.append([im]) + + return images_gif + + +def update_epoch_gif(x, trained_model, aug_strategy, + multi_class, which_slice, which_volume=1, + epoch_limit=1000, + gif_dir='', + gif_cmap='gray', + clean=False): + + images_gif = [] + + x_slice = np.expand_dims(x[which_slice-1], axis=0) + print('Input image data type: {}, shape: {}\n'.format(type(x_slice), x_slice.shape)) + + print('predicting slice {}'.format(which_slice)) + predicted_slice = trained_model.predict(x_slice) + if multi_class: + predicted_slice = np.argmax(predicted_slice, axis=-1) + else: + predicted_slice = np.squeeze(predicted_slice, axis=-1) + + print('slice predicted\n') + + print("adding prediction to the queue") + im = ax.imshow(predicted_slice[0], cmap=gif_cmap, animated=True) + if not clean: + text = ax.text(0.5,1.05,f"Epoch {int(name.split('.')[1])}", + size=plt.rcParams["axes.titlesize"], + ha="center", transform=ax.transAxes) + images_gif.append([im, text]) + else: + ax.axis('off') + images_gif.append([im]) + print("prediction added\n") + + return images_gif +########## + + + +########## Plotly npys ########## +def update_volume_npy(y, pred, target, + sample_pred, sample_y, + visual_file, name, + which_volume, multi_class): + batch_size = y.shape[0] + y = np.array(y) + pred = np.array(pred) + + + if (get_depth(sample_pred) + batch_size) < target: # check if next batch will fit in volume (160) + sample_pred.append(pred) + del pred + sample_y.append(y) + del y + else: + remaining = target - get_depth(sample_pred) + sample_pred.append(pred[:remaining]) + sample_y.append(y[:remaining]) + pred_vol = np.concatenate(sample_pred) + del sample_pred + y_vol = np.concatenate(sample_y) + del sample_y + sample_pred = [pred[remaining:]] + sample_y = [y[remaining:]] + + del pred + del y + + ###################### + # print("===============") + # print("pred done") + # print(pred_vol.shape) + # print(y_vol.shape) + # print("===============") + # print('multi_class', multi_class) + ###################### + + if multi_class: # or np.shape(pred_vol)[-1] not + pred_vol = np.argmax(pred_vol, axis=-1) + y_vol = np.argmax(y_vol, axis=-1) + + ###################### + # print('np.shape(pred_vol)', np.shape(pred_vol)) + # print('np.shape(y_vol)',np.shape(y_vol)) + ###################### + + # Save volume as numpy file for plotlyyy + fig_dir = "results" + name_pred_npy = os.path.join(fig_dir, "pred", (visual_file + "_" + name)) + name_y_npy = os.path.join(fig_dir, "ground_truth", (visual_file + "_vol" + str(which_volume).zfill(3))) + + ###################### + # print("npy save pred as ", name_pred_npy) + # print("npy save y as ", name_y_npy) + # print("Currently on vol ", idx_vol) + ###################### + + + # Get middle xx slices cuz 288x288x160 too big + roi = int(80 / 2) + d1,d2,d3 = np.shape(pred_vol)[0:3] + d1, d2, d3 = int(np.floor(d1/2)), int(np.floor(d2/2)), int(np.floor(d3/2)) + pred_vol = pred_vol[(d1-roi):(d1+roi),(d2-roi):(d2+roi), (d3-roi):(d3+roi)] + d1,d2,d3 = np.shape(y_vol)[0:3] + d1, d2, d3 = int(np.floor(d1/2)), int(np.floor(d2/2)), int(np.floor(d3/2)) + y_vol = y_vol[(d1-roi):(d1+roi),(d2-roi):(d2+roi), (d3-roi):(d3+roi)] + + ###################### + print('y_vol.shape', np.shape(y_vol)) + ###################### + + np.save(name_pred_npy,pred_vol) + np.save(name_y_npy,y_vol) + ###################### + print("Total voxels saved, pred:", np.sum(pred_vol), "y:", np.sum(y_vol)) + ###################### + + sample_pred = [] + sample_y = [] + del pred_vol + del y_vol + + return sample_pred, sample_y + +########## + + + + + +def eval_loop(dataset, validation_steps, aug_strategy, + bucket_name, logdir, tpu_name, visual_file, weights_dir, + fig_dir, + which_volume, which_epoch, which_slice, + multi_as_binary, + trained_model, model_architecture, + callbacks, + num_classes=7 + ): + + """ Evaluate model and visualize as needed """ + + multi_class = num_classes > 1 + gif_dir='' + + # load the checkpoints in the specified log directory + session_weights = get_all_weights(bucket_name, logdir, tpu_name, visual_file, weights_dir) + last_epoch = len(session_weights) + + # trained_model.load_weights(weights_dir).expect_partial() + # trained_model.evaluate(dataset, steps=validation_steps, callbacks=callbacks) + + + + + # Callbacks (as in og conf matrix function) + f = weights_dir.split('/')[-1] + # Excluding parenthese before f too + if weights_dir.endswith(f): + writer_dir = weights_dir[:-(len(f)+1)] + writer_dir = os.path.join(writer_dir, 'eval') + eval_metric_writer = tf.summary.create_file_writer(writer_dir) + + + # Init visuals + cm, classes = initialize_cm(multi_class, num_classes) + fig, axes, images_gif = initialize_gif() + target = 160 # how many slices in 1 vol + sample_pred = [] # prediction for current 160,288,288 vol + sample_y = [] # y for current 160,288,288 vol + + + + + for chkpt in session_weights: + ### Skip to last chkpt if you only want evaluation + + name = chkpt.split('/')[-1] + name = name.split('.inde')[0] + epoch = name.split('.')[1] + + ######################### + print("\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + print(f"\t\tLoading weights from {epoch} epoch") + print(f"\t\t {name}") + print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + ######################### + + trained_model.load_weights('gs://' + os.path.join(bucket_name, + weights_dir, + tpu_name, + visual_file, + name)).expect_partial() + if epoch==last_epoch: + trained_model.evaluate(dataset, steps=validation_steps, callbacks=callbacks) + + + # Initializing volume saving + sample_pred = [] # prediction for current 160,288,288 vol + sample_y = [] # y for current 160,288,288 vol + + for step, (x, label) in enumerate(dataset): + print('step',step) + pred = trained_model.predict(x) + + + # Update visuals + cm = update_cm(cm, num_classes) + visualise_multi_class(label, pred) + + if step+1 == which_volume: + update_gif_slice(x, label, trained_model, + aug_strategy, + multi_as_binary, multi_class, + which_epoch, which_slice) + + images_gif = update_volume_comp_gif(x,label, images_gif, trained_model, + multi_class, + which_epoch, + gif_dir=gif_dir) + + images_gif = update_epoch_gif(x, trained_model, aug_strategy, + multi_class, which_slice, + gif_dir=gif_dir) + + sample_pred, sample_y = update_volume_npy(label, pred, target, + sample_pred, sample_y, + visual_file, name, + which_volume, multi_class) + + + + # if multi_class: + # iou = iou_loss_eval(label, pred) + # dice = dice_coef_eval(label, pred) + # else: + # iou = iou_loss(label, pred) + # dice = dice_coef(label, pred) + iou = iou_loss_eval(label, pred) if multi_class else iou_loss(label, pred) + dice = dice_coef_eval(label, pred) if multi_class else dice_coef(label, pred) + + with eval_metric_writer.as_default(): + tf.summary.scalar('iou eval validation', iou, step=step) + tf.summary.scalar('dice eval validation', dice, step=step) + + # Save visuals + save_cm(cm, model_architecture, fig_dir, classes) + pred_evolution_gif(fig, images_gif, save_dir=gif_dir, save=True, no_margins=False) +