--- /dev/null
+++ b/DigiPathAI/new_Segmentation.py
@@ -0,0 +1,687 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import os
+import glob
+import random
+
+import imgaug
+from imgaug import augmenters as iaa
+from PIL import Image
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+import openslide
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import backend as K
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, concatenate, Concatenate, UpSampling2D, Activation
+from tensorflow.keras.losses import categorical_crossentropy
+from tensorflow.keras.applications.densenet import DenseNet121
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
+from tensorflow.keras import metrics
+
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms  # noqa
+
+import sklearn.metrics
+import io
+import itertools
+from six.moves import range
+
+import time
+import argparse 
+import cv2
+from skimage.color import rgb2hsv
+from skimage.filters import threshold_otsu
+
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
+from models.seg_models import get_inception_resnet_v2_unet_softmax, unet_densenet121
+from models.deeplabv3p_original import Deeplabv3
+
+# Random Seeds
+np.random.seed(0)
+random.seed(0)
+tf.set_random_seed(0)
+import gc
+import pandas as pd
+
+import tifffile
+import skimage.io as skio  # renamed to avoid shadowing the stdlib io module imported above
+import DigiPathAI
+
+# Image Helper Functions
+def imsave(*args, **kwargs):
+    """
+    Concatenates the images given in args and saves them as a single image
+    at the specified output destination. Images should be numpy arrays with
+    the same size along axis 0.
+    imsave(im1, im2, out="sample.png")
+    """
+    args_list = list(args)
+    for i in range(len(args_list)):
+        if type(args_list[i]) != np.ndarray:
+            print("Not a numpy array")
+            return 0
+        if len(args_list[i].shape) == 2:
+            # Promote grayscale images to 3 channels and scale binary masks to [0, 255]
+            args_list[i] = np.dstack([args_list[i]]*3)
+            if args_list[i].max() == 1:
+                args_list[i] = args_list[i]*255
+
+    out_destination = kwargs.get("out", '')
+    try:
+        concatenated_arr = np.concatenate(args_list, axis=1)
+        im = Image.fromarray(np.uint8(concatenated_arr))
+    except Exception as e:
+        print(e)
+        return 0
+    if out_destination:
+        print("Saving to %s" % (out_destination))
+        im.save(out_destination)
+    else:
+        return im
+
+def imshow(*args,**kwargs):
+    """ Handy function to show multiple plots in on row, possibly with different cmaps and titles
+    Usage:
+    imshow(img1, title="myPlot")
+    imshow(img1,img2, title=['title1','title2'])
+    imshow(img1,img2, cmap='hot')
+    imshow(img1,img2,cmap=['gray','Blues']) """
+    cmap = kwargs.get('cmap', 'gray')
+    title= kwargs.get('title','')
+    axis_off = kwargs.get('axis_off','')
+    if len(args)==0:
+        raise ValueError("No images given to imshow")
+    elif len(args)==1:
+        plt.title(title)
+        plt.imshow(args[0], interpolation='none')
+    else:
+        n=len(args)
+        if type(cmap)==str:
+            cmap = [cmap]*n
+        if type(title)==str:
+            title= [title]*n
+        plt.figure(figsize=(n*5,10))
+        for i in range(n):
+            plt.subplot(1,n,i+1)
+            plt.title(title[i])
+            plt.imshow(args[i], cmap[i])
+            if axis_off:
+                plt.axis('off')
+    plt.show()
+
+def normalize_minmax(data):
+    """
+    Normalize contrast across volume
+    """
+    _min = float(np.min(data))
+    _max = float(np.max(data))
+    if (_max-_min)!=0:
+        img = (data - _min) / (_max-_min)
+    else:
+        img = np.zeros_like(data)            
+    return img
+
+# Functions
+def BinMorphoProcessMask(mask,level):
+    """
+    Binary operation performed on tissue mask
+    """
+    close_kernel = np.ones((20, 20), dtype=np.uint8)
+    image_close = cv2.morphologyEx(np.array(mask), cv2.MORPH_CLOSE, close_kernel)
+    open_kernel = np.ones((5, 5), dtype=np.uint8)
+    image_open = cv2.morphologyEx(np.array(image_close), cv2.MORPH_OPEN, open_kernel)
+    if level == 2:
+        kernel = np.ones((60, 60), dtype=np.uint8)
+    elif level == 3:
+        kernel = np.ones((35, 35), dtype=np.uint8)
+    else:
+        raise ValueError
+    image = cv2.dilate(image_open,kernel,iterations = 1)
+    return image
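+
+# A minimal usage sketch (hedged; `tissue_mask` is a hypothetical binary mask
+# extracted at level 3 of a slide):
+# clean_mask = BinMorphoProcessMask(np.uint8(tissue_mask), level=3)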
+
+def get_bbox(cont_img, rgb_image=None):
+    temp_img = np.uint8(cont_img.copy())
+    # findContours returns 3 values in OpenCV 3.x and 2 in 4.x; [-2] works for both
+    contours = cv2.findContours(temp_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
+    rgb_contour = None
+    if rgb_image is not None:
+        rgb_contour = rgb_image.copy()
+        line_color = (0, 0, 255)  # blue color code
+        cv2.drawContours(rgb_contour, contours, -1, line_color, 2)
+    bounding_boxes = [cv2.boundingRect(c) for c in contours]
+    if rgb_contour is not None:
+        for x, y, w, h in bounding_boxes:
+            rgb_contour = cv2.rectangle(rgb_contour, (x, y), (x+w, y+h), (0, 255, 0), 2)
+    return bounding_boxes, rgb_contour
+
+def get_all_bbox_masks(mask, stride_factor):
+    """
+    Find the bounding boxes and the corresponding mask
+    """
+    bbox_mask = np.zeros_like(mask)
+    bounding_boxes, _ = get_bbox(mask)
+    y_size, x_size = bbox_mask.shape
+    for x, y, w, h in bounding_boxes:
+        x_min = x - stride_factor
+        x_max = x + w + stride_factor
+        y_min = y - stride_factor
+        y_max = y + h + stride_factor
+        if x_min < 0:
+            x_min = 0
+        if y_min < 0:
+            y_min = 0
+        if x_max > x_size:
+            x_max = x_size - 1
+        if y_max > y_size:
+            y_max = y_size - 1
+        bbox_mask[y_min:y_max, x_min:x_max] = 1
+    return bbox_mask
+
+def get_all_bbox_masks_with_stride(mask, stride_factor):
+    """
+    Find the bounding boxes and the corresponding strided mask
+    """
+    bbox_mask = np.zeros_like(mask)
+    bounding_boxes, _ = get_bbox(mask)
+    y_size, x_size = bbox_mask.shape
+    for x, y, w, h in bounding_boxes:
+        x_min = x - stride_factor
+        x_max = x + w + stride_factor
+        y_min = y - stride_factor
+        y_max = y + h + stride_factor
+        if x_min < 0:
+            x_min = 0
+        if y_min < 0:
+            y_min = 0
+        if x_max > x_size:
+            x_max = x_size - 1
+        if y_max > y_size:
+            y_max = y_size - 1
+        bbox_mask[y_min:y_max:stride_factor, x_min:x_max:stride_factor] = 1
+
+    return bbox_mask
+
+def find_largest_bbox(mask, stride_factor):
+    """
+    Find the largest bounding box encompassing all the blobs
+    """
+    y_size, x_size = mask.shape
+    x, y = np.where(mask==1)
+    bbox_mask = np.zeros_like(mask)
+    x_min = np.min(x) - stride_factor
+    x_max = np.max(x) + stride_factor
+    y_min = np.min(y) - stride_factor
+    y_max = np.max(y) + stride_factor
+
+    if x_min < 0:
+        x_min = 0
+    if y_min < 0:
+        y_min = 0
+    if x_max > x_size:
+        x_max = x_size - 1
+    if y_max > y_size:
+        y_max = y_size - 1
+
+    bbox_mask[x_min:x_max, y_min:y_max] = 1
+    return bbox_mask
+    
+def TissueMaskGeneration(slide_obj, level, RGB_min=50):
+    img_RGB = slide_obj.read_region((0, 0),level,slide_obj.level_dimensions[level])
+    img_RGB = np.transpose(np.array(img_RGB.convert('RGB')),axes=[1,0,2])
+    img_HSV = rgb2hsv(img_RGB)
+    background_R = img_RGB[:, :, 0] > threshold_otsu(img_RGB[:, :, 0])
+    background_G = img_RGB[:, :, 1] > threshold_otsu(img_RGB[:, :, 1])
+    background_B = img_RGB[:, :, 2] > threshold_otsu(img_RGB[:, :, 2])
+    tissue_RGB = np.logical_not(background_R & background_G & background_B)
+    tissue_S = img_HSV[:, :, 1] > threshold_otsu(img_HSV[:, :, 1])
+    min_R = img_RGB[:, :, 0] > RGB_min
+    min_G = img_RGB[:, :, 1] > RGB_min
+    min_B = img_RGB[:, :, 2] > RGB_min
+
+    tissue_mask = tissue_S & tissue_RGB & min_R & min_G & min_B
+    # r = img_RGB[:,:,0] < 235
+    # g = img_RGB[:,:,1] < 210
+    # b = img_RGB[:,:,2] < 235
+    # tissue_mask = np.logical_or(r,np.logical_or(g,b))
+    return tissue_mask 
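+
+# A minimal usage sketch (hedged; the slide path is hypothetical):
+# slide = openslide.OpenSlide('sample.tiff')
+# tissue = TissueMaskGeneration(slide, level=3)  # boolean array in the transposed (x, y) convention
+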
+def TissueMaskGenerationPatch(patchRGB):
+    '''
+    Returns mask of tissue that obeys the threshold set by paip
+    '''
+    r = patchRGB[:,:,0] < 235
+    g = patchRGB[:,:,1] < 210
+    b = patchRGB[:,:,2] < 235
+    tissue_mask = np.logical_or(r,np.logical_or(g,b))
+    return tissue_mask 
+    
+def TissueMaskGeneration_BIN(slide_obj, level):
+    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
+                       level,
+                       slide_obj.level_dimensions[level]).convert('RGB')),
+                       axes=[1, 0, 2])    
+    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_BGR2HSV)
+    img_S = img_HSV[:, :, 1]
+    _,tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY)
+    return np.array(tissue_mask)
+
+def TissueMaskGeneration_BIN_OTSU(slide_obj, level):
+    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
+                       level,
+                       slide_obj.level_dimensions[level]).convert('RGB')),
+                       axes=[1, 0, 2])    
+    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_BGR2HSV)
+    img_S = img_HSV[:, :, 1]
+    _,tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+    return np.array(tissue_mask)
+
+def labelthreshold(image, threshold=0.5):
+    np.place(image,image>=threshold, 1)
+    np.place(image,image<threshold, 0)
+    return np.uint8(image)
+
+def calc_jacc_score(x,y,smoothing=1):
+    for var in [x,y]:
+        np.place(var,var==255,1)
+    
+    numerator = np.sum(x*y)
+    denominator = np.sum(np.logical_or(x,y))
+    return (numerator+smoothing)/(denominator+smoothing)
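+
+# Worked example (hedged, shown as comments so it does not run at import time):
+# x = np.array([[1, 1], [0, 0]]); y = np.array([[1, 0], [0, 0]])
+# intersection = 1 and union = 2, so with the default smoothing of 1 the
+# score is (1 + 1) / (2 + 1) = 0.667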
+
+# DataLoader Implementation
+class WSIStridedPatchDataset(Dataset):
+    """
+    Data producer that generates all the square grids, e.g. 3x3, of patches,
+    from a WSI and its tissue mask, and their corresponding indices with
+    respect to the tissue mask
+    """
+    def __init__(self, wsi_path, mask_path, label_path=None, image_size=256,
+                 normalize=True, flip='NONE', rotate='NONE',                
+                 level=5, sampling_stride=16, roi_masking=True):
+        """
+        Initialize the data producer.
+
+        Arguments:
+            wsi_path: string, path to the WSI file
+            mask_path: string, path to the tissue mask (tiff or numpy format), or
+                            None to generate the mask on the fly
+            label_path: string, path to the ground-truth label mask in tif format, or
+                            None (in case of a normal WSI or at test time)
+            image_size: int, size of the square patch extracted at level 0, e.g. 256
+            normalize: bool, whether to normalize the [0, 255] pixel values to [-1, 1],
+                mostly False for debugging purposes
+            flip: string, 'NONE' or 'FLIP_LEFT_RIGHT' indicating the flip type
+            rotate: string, 'NONE' or 'ROTATE_90' or 'ROTATE_180' or
+                'ROTATE_270', indicating the rotate type
+            level: int, level at which the WSI tissue mask is extracted
+            roi_masking: True multiplies the strided mask with the tissue mask to
+                                eliminate white spaces;
+                                False ensures inference is done on the entire WSI
+            sampling_stride: number of pixels to skip in the tissue mask; effectively
+                            controls the overlap between patches extracted from the
+                            WSI during inference.
+                            stride=1 -> consecutive pixels are utilized
+                            stride=image_size/pow(2, level) -> non-overlapping patches
+        """
+        self._wsi_path = wsi_path
+        self._mask_path = mask_path
+        self._label_path = label_path
+        self._image_size = image_size
+        self._normalize = normalize
+        self._flip = flip
+        self._rotate = rotate
+        self._level = level
+        self._sampling_stride = sampling_stride
+        self._roi_masking = roi_masking
+        
+        self._preprocess()
+
+    def _preprocess(self):
+        self._slide = openslide.OpenSlide(self._wsi_path)
+        
+        if self._label_path is not None:
+            self._label_slide = openslide.OpenSlide(self._label_path)
+        
+        X_slide, Y_slide = self._slide.level_dimensions[0]
+        print("Image dimensions: (%d,%d)" %(X_slide,Y_slide))
+        
+        factor = self._sampling_stride
+
+        
+        if self._mask_path is not None:
+            mask_file_name = os.path.basename(self._mask_path)
+            if mask_file_name.endswith('.tiff'):
+                mask_obj = openslide.OpenSlide(self._mask_path)
+                self._mask = np.array(mask_obj.read_region((0, 0),
+                       self._level,
+                       mask_obj.level_dimensions[self._level]).convert('L')).T
+                np.place(self._mask,self._mask>0,255)
+            else:
+                # Assume a mask saved in numpy format at the same level
+                # (as written by save_get_mask below)
+                self._mask = np.load(self._mask_path)
+        else:
+            # Generate tissue mask on the fly
+            self._mask = TissueMaskGeneration(self._slide, self._level)
+        # morphological operations ensure the holes are filled in tissue mask
+        # and minor points are aggregated to form a larger chunk         
+
+        self._mask = BinMorphoProcessMask(np.uint8(self._mask),self._level)
+        # self._all_bbox_mask = get_all_bbox_masks(self._mask, factor)
+        # self._largest_bbox_mask = find_largest_bbox(self._mask, factor)
+        # self._all_strided_bbox_mask = get_all_bbox_masks_with_stride(self._mask, factor)
+
+        X_mask, Y_mask = self._mask.shape
+        # print (self._mask.shape, np.where(self._mask>0))
+        # imshow(self._mask.T)
+        # the cm17 dataset had issues with images being precise powers of 2
+        print('Mask (%d,%d) and Slide (%d,%d)'%(X_mask,Y_mask,X_slide,Y_slide))
+        if X_slide // X_mask != Y_slide // Y_mask:
+            raise Exception('Slide/Mask dimension does not match ,'
+                            ' X_slide / X_mask : {} / {},'
+                            ' Y_slide / Y_mask : {} / {}'
+                            .format(X_slide, X_mask, Y_slide, Y_mask))
+
+        self._resolution = np.round(X_slide * 1.0 / X_mask)
+        if not np.log2(self._resolution).is_integer():
+            raise Exception('Resolution (X_slide / X_mask) is not a power of 2:'
+                            ' {}'.format(self._resolution))
+             
+        # all the indices for the tissue region from the tissue mask
+        self._strided_mask =  np.ones_like(self._mask)
+        ones_mask = np.zeros_like(self._mask)
+        ones_mask[::factor, ::factor] = self._strided_mask[::factor, ::factor]
+        
+        
+        if self._roi_masking:
+            self._strided_mask = ones_mask*self._mask   
+            # self._strided_mask = ones_mask*self._largest_bbox_mask   
+            # self._strided_mask = ones_mask*self._all_bbox_mask 
+            # self._strided_mask = self._all_strided_bbox_mask  
+        else:
+            self._strided_mask = ones_mask  
+        # print (np.count_nonzero(self._strided_mask), np.count_nonzero(self._mask[::factor, ::factor]))
+        # imshow(self._strided_mask.T, self._mask[::factor, ::factor].T)
+        # imshow(self._mask.T, self._strided_mask.T)
+ 
+        self._X_idcs, self._Y_idcs = np.where(self._strided_mask)        
+        self._idcs_num = len(self._X_idcs)
+
+    def __len__(self):        
+        return self._idcs_num 
+
+    def save_scaled_imgs(self):
+        scld_dms = self._slide.level_dimensions[self._level]
+        self._slide_scaled = self._slide.read_region((0,0),self._level,scld_dms)
+        
+        if self._label_path is not None:
+            self._label_scaled = np.array(self._label_slide.read_region((0,0),self._level,scld_dms).convert('L'))
+            np.place(self._label_scaled,self._label_scaled>0,255)
+        
+    def save_get_mask(self, save_path):
+        np.save(save_path, self._mask)
+
+    def get_mask(self):
+        return self._mask
+
+    def get_strided_mask(self):
+        return self._strided_mask
+    
+    def __getitem__(self, idx):
+        x_coord, y_coord = self._X_idcs[idx], self._Y_idcs[idx]
+        
+        x_max_dim,y_max_dim = self._slide.level_dimensions[0]
+
+        # Center the patch on the sampled tissue-mask coordinate
+        x = int(x_coord * self._resolution - self._image_size//2)
+        y = int(y_coord * self._resolution - self._image_size//2)
+
+        # If the patch goes out of bounds, clamp it to the slide
+        if x > (x_max_dim - self._image_size):
+            x = x_max_dim - self._image_size
+        elif x < 0:
+            x = 0
+        if y > (y_max_dim - self._image_size):
+            y = y_max_dim - self._image_size
+        elif y < 0:
+            y = 0
+
+        # Keep the patch as a PIL image so the flip/rotate transposes below work;
+        # the conversion to a numpy array happens after augmentation
+        img = self._slide.read_region(
+            (x, y), 0, (self._image_size, self._image_size)).convert('RGB')
+        
+        if self._label_path is not None:
+            label_img = self._label_slide.read_region(
+                (x, y), 0, (self._image_size, self._image_size)).convert('L')
+        else:
+            #print('No label img')
+            label_img = Image.fromarray(np.zeros((self._image_size, self._image_size), dtype=np.uint8))
+        
+        if self._flip == 'FLIP_LEFT_RIGHT':
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+            label_img = label_img.transpose(Image.FLIP_LEFT_RIGHT)
+            
+        if self._rotate == 'ROTATE_90':
+            img = img.transpose(Image.ROTATE_90)
+            label_img = label_img.transpose(Image.ROTATE_90)
+            
+        if self._rotate == 'ROTATE_180':
+            img = img.transpose(Image.ROTATE_180)
+            label_img = label_img.transpose(Image.ROTATE_180)
+
+        if self._rotate == 'ROTATE_270':
+            img = img.transpose(Image.ROTATE_270)
+            label_img = label_img.transpose(Image.ROTATE_270)
+
+        # PIL image: H x W x C as a numpy array; transpose w and h back to the
+        # (x, y) convention used throughout this file
+        img = np.transpose(np.array(img, dtype=np.float32), [1, 0, 2])
+        label_img = np.array(label_img, dtype=np.uint8)
+        np.place(label_img, label_img>0, 255)
+
+        if self._normalize:
+            img = (img - 128.0)/128.0
+   
+        return (img, x, y, label_img)
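+
+# A minimal usage sketch (hedged; the slide path is hypothetical):
+# dataset = WSIStridedPatchDataset('slide.tiff', mask_path=None, label_path=None,
+#                                  image_size=256, level=5, sampling_stride=16)
+# loader = DataLoader(dataset, batch_size=8, num_workers=2)
+# for img, x, y, label in loader:
+#     pass  # img: (B, 256, 256, 3) float32 in [-1, 1]; x, y: level-0 coordinates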
+
+def getSegmentation(img_path,
+                    patch_size  = 256,
+                    stride_size = 128,
+                    batch_size  = 32,
+                    quick       = True,
+                    tta_list    = None,
+                    crf         = False,
+                    status      = None):
+    """
+        Saves the prediction at the same location as the input image
+        args:
+            img_path: WSI tiff image path (str)
+            patch_size: patch size for inference (int)
+            stride_size: stride to skip during segmentation (int)
+            batch_size: batch_size during inference (int)
+            quick: if True; prediction is of single model (bool)
+                    else: final segmentation is ensemble of 4 different models
+            tta_list: type of augmentation required during inference
+                     allowed: ['FLIP_LEFT_RIGHT', 'ROTATE_90', 'ROTATE_180', 'ROTATE_270'] (list(str))
+            crf: application of conditional random fields in post processing step (bool)
+            status: required for webserver (json)
+
+        return :
+            saves the prediction in given path (in .tiff format)
+            prediction: predicted segmentation mask
+
+    """
+    #Model loading
+    core_config = tf.ConfigProto()
+    core_config.gpu_options.allow_growth = False
+    session = tf.Session(config=core_config)
+    K.set_session(session)
+
+    def load_incep_resnet(model_path):
+        model = get_inception_resnet_v2_unet_softmax((None, None), weights=None)
+        model.load_weights(model_path)
+        print ("Loaded Model Weights %s" % model_path)
+        return model
+
+    def load_unet_densenet(model_path):
+        model = unet_densenet121((None, None), weights=None)
+        model.load_weights(model_path)
+        print ("Loaded Model Weights %s" % model_path)
+        return model
+
+    def load_deeplabv3(model_path, OS):
+        model = Deeplabv3(input_shape=(patch_size, patch_size, 3),weights=None,classes=2,activation='softmax',backbone='xception',OS=OS)
+        model.load_weights(model_path)
+        print ("Loaded Model Weights %s" % model_path)
+        return model
+
+    model_path_root = os.path.join(DigiPathAI.digipathai_folder,'digestpath_models')
+    model_dict = {}
+    if quick:
+        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root,'densenet.h5'))
+    else:
+        model_dict['inception'] = load_incep_resnet(os.path.join(model_path_root,'inception.h5'))
+        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root,'densenet.h5'))
+        model_dict['deeplab'] = load_deeplabv3(os.path.join(model_path_root,'deeplab.h5'))
+    
+    # Capture the real model keys before defining the ensemble key, so that
+    # predict() below is only called on actual models
+    model_keys = list(model_dict.keys())
+    ensemble_key = 'ensemble'
+    models_to_save = [ensemble_key]
+        
+    #Stitcher
+    start_time = time.time()
+    wsi_path = img_path
+    wsi_obj = openslide.OpenSlide(wsi_path)
+    x_max_dim,y_max_dim = wsi_obj.level_dimensions[0]
+    count_map = np.zeros(wsi_obj.level_dimensions[0],dtype='uint8')
+
+    prd_im_fll_dict = {}
+    memmaps_path = os.path.join(DigiPathAI.digipathai_folder,'memmaps')
+    os.makedirs(memmaps_path,exist_ok=True)
+    for key in models_to_save:
+        prd_im_fll_dict[key] = np.memmap(os.path.join(memmaps_path,'%s.dat'%(key)), dtype=np.float32,mode='w+', shape=(wsi_obj.level_dimensions[0]))
+
+    #Take the smallest resolution available
+    level = len(wsi_obj.level_dimensions) -1 
+    scld_dms = wsi_obj.level_dimensions[-1]
+    scale_sampling_stride = stride_size//int(wsi_obj.level_downsamples[level])
+    print("Level %d , stride %d, scale stride %d" %(level,stride_size, scale_sampling_stride))
+    
+    scale = lambda x: cv2.resize(x,tuple(reversed(scld_dms))).T
+    mask_path = None
+    start_time = time.time()
+    dataset_obj = WSIStridedPatchDataset(wsi_path, 
+                                        mask_path=None,
+                                        label_path=None,
+                                        image_size=patch_size,
+                                        normalize=True,
+                                        flip='NONE', rotate='NONE',
+                                        level=level, sampling_stride=scale_sampling_stride, roi_masking=True)
+
+    # drop_last=False so the tail patches are not silently skipped during inference
+    dataloader = DataLoader(dataset_obj, batch_size=batch_size, num_workers=batch_size, drop_last=False)
+    dataset_obj.save_scaled_imgs()
+
+    print(dataset_obj.get_mask().shape)
+    mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
+    st_im = np.dstack([dataset_obj.get_strided_mask().T]*3).astype('uint8')*255
+    im_im = np.array(dataset_obj._slide_scaled.convert('RGB'))
+    ov_im = mask_im/2 + im_im/2
+    ov_im_stride = st_im/2 + im_im/2
+
+    print("Total iterations: %d %d" % (dataloader.__len__(), dataloader.dataset.__len__()))
+    for i,(data, xes, ys, label) in enumerate(dataloader):
+        tmp_pls= lambda x: x + patch_size
+        tmp_mns= lambda x: x 
+        image_patches = data.cpu().data.numpy()
+        
+        pred_map_dict = {}
+        pred_map_dict[ensemble_key] = 0
+        for key in model_keys:
+            pred_map_dict[key] = model_dict[key].predict(image_patches,verbose=0,batch_size=batch_size)
+            pred_map_dict[ensemble_key]+=pred_map_dict[key]
+        pred_map_dict[ensemble_key]/=len(model_keys)
+    
+        actual_batch_size =  image_patches.shape[0]
+        for j in range(actual_batch_size):
+            x = int(xes[j])
+            y = int(ys[j])
+
+            wsi_img = image_patches[j]*128+128
+            patch_mask = TissueMaskGenerationPatch(wsi_img)
+
+            for key in models_to_save:
+                prediction = pred_map_dict[key][j,:,:,1]
+                prediction*=patch_mask
+                prd_im_fll_dict[key][tmp_mns(x):tmp_pls(x),tmp_mns(y):tmp_pls(y)] += prediction
+
+            count_map[tmp_mns(x):tmp_pls(x),tmp_mns(y):tmp_pls(y)] += np.ones((patch_size,patch_size),dtype='uint8')
+        if (i+1)%100 == 0 or i < 10:
+            print("Completed %i Time elapsed %.2f min | Max count %d "%(i,(time.time()-start_time)/60,count_map.max()))
+        
+    print("Fully completed %i Time elapsed %.2f min | Max count %d "%(i,(time.time()-start_time)/60,count_map.max()))
+    start_time = time.time()
+
+    print("\t Dividing by count_map")
+    np.place(count_map, count_map==0, 1)
+    for key in models_to_save:
+        prd_im_fll_dict[key]/=count_map
+    del count_map
+    gc.collect()
+
+    print("\t Scaling prediciton")
+    prob_map_dict = {}
+    for key in  models_to_save:
+        prob_map_dict[key] = scale(prd_im_fll_dict[key])
+        prob_map_dict[key] = (prob_map_dict[key]*255).astype('uint8')
+
+    print("\t Thresholding prediction")
+    threshold = 0.5
+    for key in  models_to_save:
+        np.place(prd_im_fll_dict[key],prd_im_fll_dict[key]>=threshold, 255)
+        np.place(prd_im_fll_dict[key],prd_im_fll_dict[key]<threshold, 0)
+    print("\t Calculated in %f" % ((time.time() - start_time)/60))
+    start_time = time.time()
+
+    print("\t Saving ground truth")
+    save_model_keys = models_to_save
+    save_path =  '-'.join(img_path.split('-')[:-1]+["mask"])+'.'+'.tiff'
+    for key in  models_to_save:
+        print("\t Saving to %s %s" %(save_path,key))
+        tifffile.imsave(os.path.join(save_path, prd_im_fll_dict[key].T, compress=9))
+    print("\t Calculated in %f" % ((time.time() - start_time)/60))
+    start_time = time.time()
+
+    print("\t Converting to a tiled pyramidal tiff")
+    os.system('convert ' + save_path + " -compress jpeg -quality 90 -define tiff:tile-geometry=256x256 ptif:"+save_path)
+    print("\t Calculated in %f" % ((time.time() - start_time)/60))
+    start_time = time.time()
+
+    # print("\t Generating scaled version of ground truth")
+    # scaled_prd_im_fll_dict = {}
+    # for key in  models_to_save:
+        # scaled_prd_im_fll_dict[key] = scale(prd_im_fll_dict[key])
+    # del prd_im_fll_dict
+    # gc.collect()
+
+    # mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
+    # mask_im = np.dstack([TissueMaskGenerationPatch(im_im)]*3).astype('uint8')*255
+    # for key in  models_to_save:
+        # mask_im[:,:,0] = scaled_prd_im_fll_dict[key]*255
+        # ov_prob_stride = st_im + (np.dstack([prob_map_dict[key]]*3)*255).astype('uint8')
+        # np.place(ov_prob_stride,ov_prob_stride>255,255)
+        # imsave(mask_im,ov_prob_stride,prob_map_dict[key],scaled_prd_im_fll_dict[key],im_im,out=os.path.join(out_dir_dict[key],'ref_'+out_file)+'.png')
+
+# for key in  models_to_save:
+    # with open(os.path.join(out_dir_dict[key],'jacc_scores.txt'), 'a') as f:
+        # f.write("Total,%f\n" %(total_jacc_score_dict[key]/len(sample_ids)))