Switch to side-by-side view

--- a
+++ b/documentation/tutorial/visualization.py
@@ -0,0 +1,197 @@
+"""
+By KB Girum
+"""
+# import libraries
+import os
+import glob
+import nibabel as nib
+import imageio
+import datetime
+import numpy as np
+import matplotlib.pyplot as plt
+from skimage.transform import resize
+import h5py
+
+import random
+import pandas as pd
+from tqdm import tqdm
+import cv2
+
+from numpy.random import uniform, exponential
+from itertools import cycle
+import csv
+
+from lifelines import KaplanMeierFitter
+from lifelines.plotting import plot_lifetimes
+from lifelines.utils import concordance_index
+from lifelines.statistics import logrank_test
+
+from lifelines import CoxPHFitter
+from lifelines.utils import k_fold_cross_validation
+from lifelines.plotting import add_at_risk_counts
+
+from matplotlib import pyplot as plt
+from matplotlib.pyplot import *
+from numpy import ndarray
+
+kmf = KaplanMeierFitter()
+from medpy.metric import binary
+
+import ntpath
+from scipy.ndimage import label
+
+
+def superimpose_segmentation_images(pet_gt_prd_display, file_name, logzscore=None):
+    """
+
+    Args:
+        pet_ct_gt_prd:
+        file_name:
+        logzscore:
+    """
+    pet, gt, prd = pet_gt_prd_display[0], pet_gt_prd_display[1], pet_gt_prd_display[2]
+
+    if logzscore == "log":
+        pet = np.log(pet + 1)
+    elif logzscore == "zscore":
+        pet = (pet - np.mean(pet)) / (np.std(pet) + 1e-8)
+    elif logzscore == "clipping":
+        pet[pet > 50] = 50
+        pet /= 50
+    else:
+        pet = np.log(pet + 1)
+
+    img = pet
+    try:
+        img = np.squeeze(img, axis=-1)
+    except:
+        pass
+    try:
+        gt = np.squeeze(gt, axis=-1)
+    except:
+        pass
+
+    try:
+        prd = np.squeeze(prd, axis=-1)
+    except:
+        pass
+
+    img = np.rot90(img)
+
+    if len(prd):
+        prd = np.rot90(prd)
+        prd[prd > 0] = 1
+
+    img = 10 - img
+    if len(gt):
+        gt = np.rot90(gt)
+        gt[gt > 0] = 1
+
+        # miss classified regions
+        prd_error = prd + gt
+        prd_error[prd_error != 1] = 0
+        dice = binary.dc(prd, gt)
+        dice = np.round(dice * 100, 1)
+    else:
+        dice = 'unkown'
+
+    color = ['brg']
+    hfont = {'fontname': 'Arial'}
+    fontsize_ = 12
+    for clr in color:
+        viridis = cm.get_cmap(clr)
+        print("\n Image ID: \t %s", str(file_name))
+        fig, axs = plt.subplots(1, 3, figsize=(10, 10))
+        axs[0].imshow(img, cmap='gray')
+        axs[0].set_title('PET image', **hfont, fontsize=fontsize_)
+        axs[0].set_xticklabels([])
+        axs[0].set_yticklabels([])
+
+
+        axs[1].imshow(img, cmap='gray')
+        if len(gt):
+            gt = np.ma.masked_where(gt == 0, gt)
+            axs[1].imshow(gt, cmap=viridis)  # cmap='gray')#
+            axs[1].set_title('Expert', **hfont, fontsize=fontsize_)
+        else:
+            axs[1].set_title('No ground truth provided', **hfont, fontsize=fontsize_)
+        axs[1].set_xticklabels([])
+        axs[1].set_yticklabels([])
+        axs[1].set_aspect('equal')
+
+        axs[2].imshow(img, cmap='gray')
+        if len(prd):
+            prd = np.ma.masked_where(prd==0, prd)
+            axs[2].imshow(prd,  viridis)
+            axs[2].set_title('CNN (Dice score: {dice}%)'.format(dice=dice), **hfont, fontsize=fontsize_)
+        else:
+            axs[2].set_title('predicted image not found'.format(dice=dice), **hfont, fontsize=fontsize_)
+        axs[2].set_xticklabels([])
+        axs[2].set_yticklabels([])
+        axs[2].set_aspect('equal')
+
+        axs[0].axis('off')
+        axs[1].axis('off')
+        axs[2].axis('off')
+        # plt.savefig('images/' + str(file_name) + '.png', dpi=300)
+        plt.show()
+
+
+def display_image(im_display: ndarray, identifier: str = None):
+    """ display given array of images.
+
+    Args:
+        im_display: array of images to show
+        identifier: patient name to display as title
+    """
+    plt.figure(figsize=(10, 1))
+    plt.subplots_adjust(hspace=0.015)
+    plt.suptitle("Showing image: " + str(identifier), fontsize=12, y=0.95)
+    # loop through the length of tickers and keep track of index
+    for n, im in enumerate(im_display):
+        # add a new subplot iteratively
+        plt.subplot(int(len(im_display) // 2), 2, n + 1)
+        plt.imshow(np.log(im + 1))
+    plt.show()
+
+
+def read_predicted_images(path: str = None):
+    list_input_dir = os.listdir(path)
+    print(f'Number of cases: {len(list_input_dir)}')
+
+    for file_name in list_input_dir:
+        current_file = os.path.join(path, file_name)
+        # read ct, gt, and pet, and pred
+        pet_gt_prd = [ntpath.basename(nii) for nii in glob.glob(str(current_file) + "/*.nii")]
+        gt, pet, pred = [], [], []
+        # try:
+        for index in pet_gt_prd:
+            if "pet" in str(index).lower():
+                pet = np.asanyarray(nib.load(str(current_file) + "/" + str(index)).dataobj)
+            elif "predicted" in str(index).lower():
+                pred = np.asanyarray(nib.load(str(current_file) + "/" + str(index)).dataobj)
+            elif "ground_truth" in str(index).lower() or "gt" in str(index).lower():
+                gt = np.asanyarray(nib.load(str(current_file) + "/" + str(index)).dataobj)
+
+        if len(pred):
+            pred[pred>0.5] =1
+            pred[pred<0.5] = 0
+
+        for coronal_sagittal in range(2):
+            if len(gt) and len(pred):
+                pet_gt_prd_display = [pet[coronal_sagittal], gt[coronal_sagittal], pred[coronal_sagittal]]
+            elif len(pred):
+                pet_gt_prd_display = [pet[coronal_sagittal], gt, pred[coronal_sagittal]]
+            elif len(gt):
+                pet_gt_prd_display = [pet[coronal_sagittal], gt[coronal_sagittal], pred]
+            else:
+                pet_gt_prd_display = [pet[coronal_sagittal], gt, pred]
+
+            superimpose_segmentation_images(pet_gt_prd_display, file_name=file_name)
+        # except:
+        #     pass
+
+
+if __name__ == '__main__':
+    # Function to visualize image and clinical data
+    print("visualize data")