--- a
+++ b/code/produce_visualizations.py
@@ -0,0 +1,135 @@
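+"""Produce class-prediction visualizations for the test slides.
+
+Loads a trained network, slides an overlapping window over each
+whole-slide image, averages the patch-level softmax scores into a
+per-pixel map, and saves the argmax class overlaid on the slide.
+Optionally renders the ground-truth annotation masks with the same
+color coding (--save_annotations).
+"""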
+import os
+import random
+import argparse
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from PIL import Image
+
+from mil_data_generator import *
+from mil_models_pytorch import *
+from mil_trainer_torch import *
+
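+# Seed all RNGs for reproducibility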
+np.random.seed(42)
+random.seed(42)
+torch.manual_seed(42)
+
+
+def main(args):
+
+    def image_normalization(x, input_shape, channel_first=True):
+        # resize to the network input size (cv2.resize expects (width, height))
+        x = cv2.resize(x, (input_shape[2], input_shape[1]))
+        # scale intensities to [0, 1]
+        x = x / 255.0
+        # HWC -> CHW
+        if channel_first:
+            x = np.transpose(x, (2, 0, 1))
+        # cast to float32 (astype returns a new array, so reassign)
+        x = x.astype('float32')
+        return x
+
+    # INPUTS #
+    dir_slides = '../data/SICAP_MIL/slides/'
+    dir_data_frame = '../data/SICAP_MIL/dataframes/gt_global_slides.xlsx'
+    dir_experiment = '../data/results/' + args.experiment_name + '/'
+
+    classes = ['G3', 'G4', 'G5']
+    input_shape = (3, 224, 224)
+    images_on_ram = True
+    patch_size = 512
+    overlap = 0.25
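+    # NOTE: 'overlap' is used as a stride fraction below -- consecutive
+    # windows start patch_size * overlap (= 128) pixels apart, so adjacent
+    # patches actually share 75% of their area.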
+    save_annotations = args.save_annotations
+
+    # Load df and take only test biopsies
+    df = pd.read_excel(dir_data_frame)
+    df = df[df['Partition'] == 'test']
+    slices_test = list(df['slide_name'])
+
+    # slices in folder
+    slices = os.listdir(dir_slides)
+
+    # Load network -- we use the first iteration model as example
+    network = torch.load(dir_experiment + str(0) + '_network_weights_best.pth').cuda()
+    network.eval()  # inference mode: disable dropout and batch-norm updates
+
+    os.makedirs(dir_experiment + 'visualizations/', exist_ok=True)
+    if save_annotations:
+        os.makedirs(dir_experiment + 'visualizations_gt/', exist_ok=True)
+
+    for c, iSlide in enumerate(slices, start=1):
+        print(f'{c}/{len(slices)}')
+
+        if iSlide.split('_')[0] in slices_test:
+
+            wsi = Image.open(os.path.join(dir_slides, iSlide))
+            wsi = np.asarray(wsi)
+
+            if save_annotations:
+                if os.path.isfile(os.path.join('../data/SICAP_MIL/annotation_masks/', iSlide)):
+                    wsi_gt = Image.open(os.path.join('../data/SICAP_MIL/annotation_masks/', iSlide))
+                    wsi_gt = np.asarray(wsi_gt)
+
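+            # Rough tissue mask: Otsu threshold on the grayscale slide
+            # (tissue is darker than the white background), then a
+            # morphological closing to fill small holes.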
+            tissue = cv2.cvtColor(wsi, cv2.COLOR_RGB2GRAY)  # PIL loads RGB, not BGR
+            ret, _ = cv2.threshold(tissue, 120, 255,
+                                   cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+            tissue = tissue < ret
+            tissue = cv2.morphologyEx(np.uint8(tissue * 255), cv2.MORPH_CLOSE,
+                                      np.ones((25, 25), np.uint8)) / 255
+
+            if not save_annotations:
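+                # 'output' accumulates the per-class softmax scores for every
+                # pixel (4 channels -- presumably non-cancerous plus G3/G4/G5);
+                # 'npatches' counts how many windows covered each pixel.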
+                output = np.zeros((wsi.shape[0], wsi.shape[1], 4))
+                npatches = np.zeros((wsi.shape[0], wsi.shape[1]))
+                x0 = 0
+                while (x0 + patch_size) <= wsi.shape[1]:
+                    y0 = 0
+                    while (y0 + patch_size) <= wsi.shape[0]:
+                        # If there is tissue, get predictions
+                        if np.mean(tissue[y0:y0+patch_size, x0:x0+patch_size]) > 0.2:
+                            # Take patch
+                            patch = wsi[y0:y0+patch_size, x0:x0+patch_size, :]
+                            # Pre-process patch
+                            x = image_normalization(patch.copy(), input_shape)
+                            x = torch.tensor(x).cuda().float().unsqueeze(0)
+                            # Forward pass (no gradients needed at inference)
+                            with torch.no_grad():
+                                features = network.bb(x)
+                                yhat = torch.softmax(network.classifier(torch.squeeze(features)), 0)
+                            yhat = yhat.cpu().numpy()
+                            # Update visualization heatmap
+                            output[y0:y0+patch_size, x0:x0+patch_size, :] += yhat
+                            npatches[y0:y0+patch_size, x0:x0+patch_size] += 1
+
+                        y0 = int(y0 + patch_size*overlap)
+                    x0 = int(x0 + patch_size*overlap)
+
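+                # Average the accumulated scores (the epsilon avoids division
+                # by zero where no window passed the tissue check), take the
+                # argmax class, and keep only tissue pixels.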
+                a = output / (np.expand_dims(npatches, -1) + 1e-6)
+                mask = np.argmax(a, axis=-1)
+                mask = mask * tissue
+
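+                # Color coding, assuming output index 0 is the non-cancerous
+                # class: R = class 3, G = class 1, B = class 2.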
+                colors = np.concatenate([np.expand_dims(mask == 3, -1),
+                                         np.expand_dims(mask == 1, -1),
+                                         np.expand_dims(mask == 2, -1)],
+                                        axis=-1).astype(np.float64)
+                overlay = wsi + 0.3 * (colors * 254)
+                overlay = np.clip(overlay, 0, 254) / 255
+
+                im = Image.fromarray((overlay * 255).astype(np.uint8))
+                im.save(dir_experiment + 'visualizations/' + iSlide)
+
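+            # Ground-truth masks encode the grades as gray levels; map them
+            # to the same R/G/B coding used for the predictions.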
+            if save_annotations and os.path.isfile(os.path.join('../data/SICAP_MIL/annotation_masks/', iSlide)):
+                colors = np.concatenate([np.expand_dims(wsi_gt >= 170, -1),
+                                         np.expand_dims((wsi_gt >= 25) & (wsi_gt < 80), -1),
+                                         np.expand_dims((wsi_gt >= 80) & (wsi_gt < 170), -1)],
+                                        axis=-1).astype(np.float64)
+                overlay = wsi + 0.3 * (colors * 254)
+                overlay = np.clip(overlay, 0, 254) / 255
+
+                im = Image.fromarray((overlay * 255).astype(np.uint8))
+                im.save(dir_experiment + 'visualizations_gt/' + iSlide)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--experiment_name", default="test_test_test", type=str)
+    parser.add_argument("--save_annotations", default=False, type=bool)
+
+    args = parser.parse_args()
+    main(args)