DigiPathAI/new_Segmentation.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import glob
import random

import imgaug
from imgaug import augmenters as iaa
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import openslide
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, concatenate, Concatenate, UpSampling2D, Activation
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
from tensorflow.keras import metrics

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms  # noqa

import sklearn.metrics
import io
import itertools
from six.moves import range

import time
import argparse
import cv2
from skimage.color import rgb2hsv
from skimage.filters import threshold_otsu

import sys
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
from models.seg_models import get_inception_resnet_v2_unet_softmax, unet_densenet121
from models.deeplabv3p_original import Deeplabv3

# Random Seeds
np.random.seed(0)
random.seed(0)
tf.set_random_seed(0)
import gc
import pandas as pd

import tifffile
import skimage.io as io
import DigiPathAI

# Image Helper Functions
def imsave(*args, **kwargs):
    """
    Concatenates the images given in args and saves them as a single image
    at the specified output destination.
    Images should be numpy arrays with the same size along axis 0.
    imsave(im1, im2, out="sample.png")
    """
    args_list = list(args)
    for i in range(len(args_list)):
        if type(args_list[i]) != np.ndarray:
            print("Not a numpy array")
            return 0
        if len(args_list[i].shape) == 2:
            args_list[i] = np.dstack([args_list[i]]*3)
            if args_list[i].max() == 1:
                args_list[i] = args_list[i]*255

    out_destination = kwargs.get("out", '')
    try:
        concatenated_arr = np.concatenate(args_list, axis=1)
        im = Image.fromarray(np.uint8(concatenated_arr))
    except Exception as e:
        print(e)
        return 0
    if out_destination:
        print("Saving to %s" % (out_destination))
        im.save(out_destination)
    else:
        return im

def imshow(*args, **kwargs):
    """ Handy function to show multiple plots in one row, possibly with different cmaps and titles
    Usage:
    imshow(img1, title="myPlot")
    imshow(img1, img2, title=['title1','title2'])
    imshow(img1, img2, cmap='hot')
    imshow(img1, img2, cmap=['gray','Blues']) """
    cmap = kwargs.get('cmap', 'gray')
    title = kwargs.get('title', '')
    axis_off = kwargs.get('axis_off', '')
    if len(args) == 0:
        raise ValueError("No images given to imshow")
    elif len(args) == 1:
        plt.title(title)
        plt.imshow(args[0], interpolation='none')
    else:
        n = len(args)
        if type(cmap) == str:
            cmap = [cmap]*n
        if type(title) == str:
            title = [title]*n
        plt.figure(figsize=(n*5, 10))
        for i in range(n):
            plt.subplot(1, n, i+1)
            plt.title(title[i])
            plt.imshow(args[i], cmap[i])
            if axis_off:
                plt.axis('off')
    plt.show()

def normalize_minmax(data):
    """
    Normalize contrast across volume
    """
    _min = float(np.min(data))
    _max = float(np.max(data))
    if (_max - _min) != 0:
        img = (data - _min) / (_max - _min)
    else:
        img = np.zeros_like(data)
    return img

# Functions
def BinMorphoProcessMask(mask, level):
    """
    Binary morphological operations performed on the tissue mask
    """
    close_kernel = np.ones((20, 20), dtype=np.uint8)
    image_close = cv2.morphologyEx(np.array(mask), cv2.MORPH_CLOSE, close_kernel)
    open_kernel = np.ones((5, 5), dtype=np.uint8)
    image_open = cv2.morphologyEx(np.array(image_close), cv2.MORPH_OPEN, open_kernel)
    if level == 2:
        kernel = np.ones((60, 60), dtype=np.uint8)
    elif level == 3:
        kernel = np.ones((35, 35), dtype=np.uint8)
    else:
        raise ValueError("Dilation kernel is only defined for levels 2 and 3")
    image = cv2.dilate(image_open, kernel, iterations=1)
    return image

def get_bbox(cont_img, rgb_image=None):
    temp_img = np.uint8(cont_img.copy())
    # cv2.findContours returns 3 values in OpenCV 3.x and 2 in OpenCV 4.x;
    # the second-to-last element is the contour list in both versions
    contours = cv2.findContours(temp_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    rgb_contour = None
    if rgb_image is not None:
        rgb_contour = rgb_image.copy()
        line_color = (0, 0, 255)  # blue color code
        cv2.drawContours(rgb_contour, contours, -1, line_color, 2)
    bounding_boxes = [cv2.boundingRect(c) for c in contours]
    if rgb_contour is not None:
        # Only draw the boxes when an RGB image was actually supplied
        for x, y, w, h in bounding_boxes:
            rgb_contour = cv2.rectangle(rgb_contour, (x, y), (x + w, y + h), (0, 255, 0), 2)
    return bounding_boxes, rgb_contour
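
# Hypothetical usage: bounding boxes come straight from cv2.boundingRect as
# (x, y, w, h) tuples.
#   boxes, overlay = get_bbox(tissue_mask, rgb_image=thumbnail)
#   boxes -> [(x, y, w, h), ...]; overlay is None when rgb_image is omitted.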

def get_all_bbox_masks(mask, stride_factor):
    """
    Find the bounding boxes and mark the corresponding mask regions
    """
    bbox_mask = np.zeros_like(mask)
    bounding_boxes, _ = get_bbox(mask)
    y_size, x_size = bbox_mask.shape
    for x, y, w, h in bounding_boxes:
        x_min = x - stride_factor
        x_max = x + w + stride_factor
        y_min = y - stride_factor
        y_max = y + h + stride_factor
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if x_max > x_size:
            x_max = x_size - 1
        if y_max > y_size:
            y_max = y_size - 1
        bbox_mask[y_min:y_max, x_min:x_max] = 1
    return bbox_mask

def get_all_bbox_masks_with_stride(mask, stride_factor):
    """
    Find the bounding boxes and mark the corresponding mask regions at the given stride
    """
    bbox_mask = np.zeros_like(mask)
    bounding_boxes, _ = get_bbox(mask)
    y_size, x_size = bbox_mask.shape
    for x, y, w, h in bounding_boxes:
        x_min = x - stride_factor
        x_max = x + w + stride_factor
        y_min = y - stride_factor
        y_max = y + h + stride_factor
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if x_max > x_size:
            x_max = x_size - 1
        if y_max > y_size:
            y_max = y_size - 1
        bbox_mask[y_min:y_max:stride_factor, x_min:x_max:stride_factor] = 1

    return bbox_mask

def find_largest_bbox(mask, stride_factor):
    """
    Find the largest bounding box encompassing all the blobs
    """
    y_size, x_size = mask.shape
    x, y = np.where(mask == 1)
    bbox_mask = np.zeros_like(mask)
    x_min = np.min(x) - stride_factor
    x_max = np.max(x) + stride_factor
    y_min = np.min(y) - stride_factor
    y_max = np.max(y) + stride_factor

    if x_min < 0:
        x_min = 0

    if y_min < 0:
        y_min = 0

    if x_max > x_size:
        x_max = x_size - 1

    if y_max > y_size:
        y_max = y_size - 1

    bbox_mask[x_min:x_max, y_min:y_max] = 1
    return bbox_mask

def TissueMaskGeneration(slide_obj, level, RGB_min=50):
    img_RGB = slide_obj.read_region((0, 0), level, slide_obj.level_dimensions[level])
    img_RGB = np.transpose(np.array(img_RGB.convert('RGB')), axes=[1, 0, 2])
    img_HSV = rgb2hsv(img_RGB)
    background_R = img_RGB[:, :, 0] > threshold_otsu(img_RGB[:, :, 0])
    background_G = img_RGB[:, :, 1] > threshold_otsu(img_RGB[:, :, 1])
    background_B = img_RGB[:, :, 2] > threshold_otsu(img_RGB[:, :, 2])
    tissue_RGB = np.logical_not(background_R & background_G & background_B)
    tissue_S = img_HSV[:, :, 1] > threshold_otsu(img_HSV[:, :, 1])
    min_R = img_RGB[:, :, 0] > RGB_min
    min_G = img_RGB[:, :, 1] > RGB_min
    min_B = img_RGB[:, :, 2] > RGB_min

    tissue_mask = tissue_S & tissue_RGB & min_R & min_G & min_B
    # r = img_RGB[:,:,0] < 235
    # g = img_RGB[:,:,1] < 210
    # b = img_RGB[:,:,2] < 235
    # tissue_mask = np.logical_or(r,np.logical_or(g,b))
    return tissue_mask

def TissueMaskGenerationPatch(patchRGB):
    """
    Returns a mask of the tissue that obeys the threshold set by PAIP
    """
    r = patchRGB[:, :, 0] < 235
    g = patchRGB[:, :, 1] < 210
    b = patchRGB[:, :, 2] < 235
    tissue_mask = np.logical_or(r, np.logical_or(g, b))
    return tissue_mask
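
# Example: a pixel of (R, G, B) = (240, 200, 240) is kept as tissue because
# its green channel is below 210, even though R and B exceed their thresholds.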

def TissueMaskGeneration_BIN(slide_obj, level):
    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
                           level,
                           slide_obj.level_dimensions[level]).convert('RGB')),
                           axes=[1, 0, 2])
    # The input is RGB, so use RGB2HSV; only the saturation channel is used
    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_RGB2HSV)
    img_S = img_HSV[:, :, 1]
    _, tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY)
    return np.array(tissue_mask)

def TissueMaskGeneration_BIN_OTSU(slide_obj, level):
    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
                           level,
                           slide_obj.level_dimensions[level]).convert('RGB')),
                           axes=[1, 0, 2])
    # The input is RGB, so use RGB2HSV; only the saturation channel is used
    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_RGB2HSV)
    img_S = img_HSV[:, :, 1]
    _, tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return np.array(tissue_mask)

def labelthreshold(image, threshold=0.5):
    np.place(image, image >= threshold, 1)
    np.place(image, image < threshold, 0)
    return np.uint8(image)

def calc_jacc_score(x, y, smoothing=1):
    # Map 255-valued masks to binary before computing the overlap
    for var in [x, y]:
        np.place(var, var == 255, 1)

    numerator = np.sum(x*y)
    denominator = np.sum(np.logical_or(x, y))
    return (numerator + smoothing)/(denominator + smoothing)
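
# Worked example (hypothetical inputs): for x = [[1, 0], [0, 0]] and
# y = [[1, 1], [0, 0]], the intersection is 1 and the union is 2, so the
# smoothed score is (1 + 1)/(2 + 1) ~= 0.67, versus an exact Jaccard of 0.5.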

# DataLoader Implementation
class WSIStridedPatchDataset(Dataset):
    """
    Data producer that generates all the square grids, e.g. 3x3, of patches
    from a WSI and its tissue mask, together with their corresponding indices
    with respect to the tissue mask.
    """
    def __init__(self, wsi_path, mask_path, label_path=None, image_size=256,
                 normalize=True, flip='NONE', rotate='NONE',
                 level=5, sampling_stride=16, roi_masking=True):
        """
        Initialize the data producer.

        Arguments:
            wsi_path: string, path to the WSI file
            mask_path: string, path to the mask file in numpy format, or None
            label_path: string, path to the ground-truth label mask in tif format,
                or None (in case of a normal WSI or at test time)
            image_size: int, size of the patch read from the WSI, e.g. 256
            normalize: bool, whether to normalize the [0, 255] pixel values to
                [-1, 1]; mostly False for debugging purposes
            flip: string, 'NONE' or 'FLIP_LEFT_RIGHT' indicating the flip type
            rotate: string, 'NONE' or 'ROTATE_90' or 'ROTATE_180' or
                'ROTATE_270', indicating the rotate type
            level: int, level at which the WSI tissue mask is extracted
            roi_masking: bool, True multiplies the strided WSI with the tissue
                mask to eliminate white spaces; False ensures inference is done
                on the entire WSI
            sampling_stride: int, number of pixels to skip in the tissue mask;
                effectively the overlap fraction when patches are extracted
                from the WSI during inference.
                stride=1 -> consecutive pixels are utilized
                stride=image_size/pow(2, level) -> non-overlapping patches,
                e.g. image_size=256 at level=5 gives a stride of 256/32 = 8
        """
        self._wsi_path = wsi_path
        self._mask_path = mask_path
        self._label_path = label_path
        self._image_size = image_size
        self._normalize = normalize
        self._flip = flip
        self._rotate = rotate
        self._level = level
        self._sampling_stride = sampling_stride
        self._roi_masking = roi_masking

        self._preprocess()

    def _preprocess(self):
        self._slide = openslide.OpenSlide(self._wsi_path)

        if self._label_path is not None:
            self._label_slide = openslide.OpenSlide(self._label_path)

        X_slide, Y_slide = self._slide.level_dimensions[0]
        print("Image dimensions: (%d,%d)" % (X_slide, Y_slide))

        factor = self._sampling_stride

        if self._mask_path is not None:
            mask_file_name = os.path.basename(self._mask_path)
            if mask_file_name.endswith('.tiff'):
                mask_obj = openslide.OpenSlide(self._mask_path)
                self._mask = np.array(mask_obj.read_region((0, 0),
                                      self._level,
                                      mask_obj.level_dimensions[self._level]).convert('L')).T
                np.place(self._mask, self._mask > 0, 255)
        else:
            # Generate the tissue mask on the fly
            self._mask = TissueMaskGeneration(self._slide, self._level)
        # Morphological operations ensure that holes in the tissue mask are
        # filled and that minor points are aggregated into larger chunks
        self._mask = BinMorphoProcessMask(np.uint8(self._mask), self._level)
        # self._all_bbox_mask = get_all_bbox_masks(self._mask, factor)
        # self._largest_bbox_mask = find_largest_bbox(self._mask, factor)
        # self._all_strided_bbox_mask = get_all_bbox_masks_with_stride(self._mask, factor)

        X_mask, Y_mask = self._mask.shape
        # print(self._mask.shape, np.where(self._mask > 0))
        # imshow(self._mask.T)
        # The cm17 dataset had issues with image dimensions being precise powers of 2
        # if X_slide != X_mask or Y_slide != Y_mask:
        print('Mask (%d,%d) and Slide (%d,%d)' % (X_mask, Y_mask, X_slide, Y_slide))
        if X_slide // X_mask != Y_slide // Y_mask:
            raise Exception('Slide/Mask dimension does not match,'
                            ' X_slide / X_mask : {} / {},'
                            ' Y_slide / Y_mask : {} / {}'
                            .format(X_slide, X_mask, Y_slide, Y_mask))

        self._resolution = np.round(X_slide * 1.0 / X_mask)
        if not np.log2(self._resolution).is_integer():
            raise Exception('Resolution (X_slide / X_mask) is not a power of 2:'
                            ' {}'.format(self._resolution))

        # All the indices for the tissue region come from the strided tissue mask
        self._strided_mask = np.ones_like(self._mask)
        ones_mask = np.zeros_like(self._mask)
        ones_mask[::factor, ::factor] = self._strided_mask[::factor, ::factor]

        if self._roi_masking:
            self._strided_mask = ones_mask*self._mask
            # self._strided_mask = ones_mask*self._largest_bbox_mask
            # self._strided_mask = ones_mask*self._all_bbox_mask
            # self._strided_mask = self._all_strided_bbox_mask
        else:
            self._strided_mask = ones_mask
        # print(np.count_nonzero(self._strided_mask), np.count_nonzero(self._mask[::factor, ::factor]))
        # imshow(self._strided_mask.T, self._mask[::factor, ::factor].T)
        # imshow(self._mask.T, self._strided_mask.T)

        self._X_idcs, self._Y_idcs = np.where(self._strided_mask)
        self._idcs_num = len(self._X_idcs)

    def __len__(self):
        return self._idcs_num

    def save_scaled_imgs(self):
        scld_dms = self._slide.level_dimensions[self._level]
        self._slide_scaled = self._slide.read_region((0, 0), self._level, scld_dms)

        if self._label_path is not None:
            self._label_scaled = np.array(self._label_slide.read_region((0, 0), self._level, scld_dms).convert('L'))
            np.place(self._label_scaled, self._label_scaled > 0, 255)

    def save_get_mask(self, save_path):
        np.save(save_path, self._mask)

    def get_mask(self):
        return self._mask

    def get_strided_mask(self):
        return self._strided_mask

    def __getitem__(self, idx):
        x_coord, y_coord = self._X_idcs[idx], self._Y_idcs[idx]

        x_max_dim, y_max_dim = self._slide.level_dimensions[0]

        # Centre the patch on the sampled tissue-mask coordinate
        x = int(x_coord * self._resolution - self._image_size//2)
        y = int(y_coord * self._resolution - self._image_size//2)

        # Clamp the coordinates if the patch goes out of bounds
        if x > (x_max_dim - self._image_size):
            x = x_max_dim - self._image_size
        elif x < 0:
            x = 0
        if y > (y_max_dim - self._image_size):
            y = y_max_dim - self._image_size
        elif y < 0:
            y = 0

        img = self._slide.read_region(
            (x, y), 0, (self._image_size, self._image_size)).convert('RGB')

        if self._label_path is not None:
            label_img = self._label_slide.read_region(
                (x, y), 0, (self._image_size, self._image_size)).convert('L')
        else:
            label_img = Image.fromarray(np.zeros((self._image_size, self._image_size), dtype=np.uint8))

        # Flips and rotations must be applied while the patch is still a PIL image
        if self._flip == 'FLIP_LEFT_RIGHT':
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            label_img = label_img.transpose(Image.FLIP_LEFT_RIGHT)

        if self._rotate == 'ROTATE_90':
            img = img.transpose(Image.ROTATE_90)
            label_img = label_img.transpose(Image.ROTATE_90)

        if self._rotate == 'ROTATE_180':
            img = img.transpose(Image.ROTATE_180)
            label_img = label_img.transpose(Image.ROTATE_180)

        if self._rotate == 'ROTATE_270':
            img = img.transpose(Image.ROTATE_270)
            label_img = label_img.transpose(Image.ROTATE_270)

        # PIL image: H x W x C; converting to a numpy array swaps w and h,
        # so transpose back to match the slide's (x, y) ordering
        img = np.transpose(np.array(img, dtype=np.float32), [1, 0, 2])
        label_img = np.array(label_img, dtype=np.uint8)
        np.place(label_img, label_img > 0, 255)

        if self._normalize:
            img = (img - 128.0)/128.0

        return (img, x, y, label_img)
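
# Minimal usage sketch (paths and values are placeholders, not from this repo):
#   dataset = WSIStridedPatchDataset('slide.tiff', mask_path=None, label_path=None,
#                                    image_size=256, level=5, sampling_stride=16)
#   loader = DataLoader(dataset, batch_size=32)
#   for img, x, y, label in loader:
#       pass  # img has shape (batch, 256, 256, 3), normalized to [-1, 1]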

def getSegmentation(img_path,
                    patch_size=256,
                    stride_size=128,
                    batch_size=32,
                    quick=True,
                    tta_list=None,
                    crf=False,
                    status=None):
    """
    Saves the prediction at the same location as the input image.

    args:
        img_path: WSI tiff image path (str)
        patch_size: patch size for inference (int)
        stride_size: stride to skip during segmentation (int)
        batch_size: batch size during inference (int)
        quick: if True, the prediction comes from a single model;
               else the final segmentation is an ensemble of 3 different models (bool)
        tta_list: type of augmentation required during inference;
                  allowed: ['FLIP_LEFT_RIGHT', 'ROTATE_90', 'ROTATE_180', 'ROTATE_270'] (list(str))
        crf: whether to apply conditional random fields as a post-processing step (bool)
        status: required for the webserver (json)

    return:
        None. The predicted segmentation mask is saved alongside the input
        image in .tiff format.
    """
    # Model loading
    core_config = tf.ConfigProto()
    core_config.gpu_options.allow_growth = False
    session = tf.Session(config=core_config)
    K.set_session(session)

    def load_incep_resnet(model_path):
        model = get_inception_resnet_v2_unet_softmax((None, None), weights=None)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    def load_unet_densenet(model_path):
        model = unet_densenet121((None, None), weights=None)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    def load_deeplabv3(model_path, OS=16):
        # OS (output stride) defaults to 16, the value commonly used for
        # DeepLabv3+ inference; the original call site did not pass it
        model = Deeplabv3(input_shape=(patch_size, patch_size, 3), weights=None,
                          classes=2, activation='softmax', backbone='xception', OS=OS)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    model_path_root = os.path.join(DigiPathAI.digipathai_folder, 'digestpath_models')
    model_dict = {}
    if quick:
        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root, 'densenet.h5'))
    else:
        model_dict['inception'] = load_incep_resnet(os.path.join(model_path_root, 'inception.h5'))
        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root, 'densenet.h5'))
        model_dict['deeplab'] = load_deeplabv3(os.path.join(model_path_root, 'deeplab.h5'))

    # Keep the list of actual models separate from the ensemble key; the
    # ensemble only labels the averaged output and has no model behind it
    model_keys = list(model_dict.keys())
    ensemble_key = 'ensemble'
    models_to_save = [ensemble_key]

    # Stitcher
    start_time = time.time()
    wsi_path = img_path
    wsi_obj = openslide.OpenSlide(wsi_path)
    x_max_dim, y_max_dim = wsi_obj.level_dimensions[0]
    count_map = np.zeros(wsi_obj.level_dimensions[0], dtype='uint8')

    prd_im_fll_dict = {}
    memmaps_path = os.path.join(DigiPathAI.digipathai_folder, 'memmaps')
    os.makedirs(memmaps_path, exist_ok=True)
    for key in models_to_save:
        prd_im_fll_dict[key] = np.memmap(os.path.join(memmaps_path, '%s.dat' % (key)),
                                         dtype=np.float32, mode='w+',
                                         shape=(wsi_obj.level_dimensions[0]))

    # Take the smallest resolution available
    level = len(wsi_obj.level_dimensions) - 1
    scld_dms = wsi_obj.level_dimensions[-1]
    scale_sampling_stride = stride_size//int(wsi_obj.level_downsamples[level])
    print("Level %d, stride %d, scale stride %d" % (level, stride_size, scale_sampling_stride))

    scale = lambda x: cv2.resize(x, tuple(reversed(scld_dms))).T
    mask_path = None
    start_time = time.time()
    dataset_obj = WSIStridedPatchDataset(wsi_path,
                                         mask_path=None,
                                         label_path=None,
                                         image_size=patch_size,
                                         normalize=True,
                                         flip=None, rotate=None,
                                         level=level, sampling_stride=scale_sampling_stride,
                                         roi_masking=True)

    dataloader = DataLoader(dataset_obj, batch_size=batch_size, num_workers=batch_size, drop_last=True)
    dataset_obj.save_scaled_imgs()

    print(dataset_obj.get_mask().shape)
    mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
    st_im = np.dstack([dataset_obj.get_strided_mask().T]*3).astype('uint8')*255
    im_im = np.array(dataset_obj._slide_scaled.convert('RGB'))
    ov_im = mask_im/2 + im_im/2
    ov_im_stride = st_im/2 + im_im/2

    print("Total iterations: %d %d" % (len(dataloader), len(dataloader.dataset)))
    for i, (data, xes, ys, label) in enumerate(dataloader):
        tmp_pls = lambda x: x + patch_size
        tmp_mns = lambda x: x
        image_patches = data.cpu().data.numpy()

        pred_map_dict = {}
        pred_map_dict[ensemble_key] = 0
        for key in model_keys:
            pred_map_dict[key] = model_dict[key].predict(image_patches, verbose=0, batch_size=batch_size)
            pred_map_dict[ensemble_key] += pred_map_dict[key]
        pred_map_dict[ensemble_key] /= len(model_keys)

        actual_batch_size = image_patches.shape[0]
        for j in range(actual_batch_size):
            x = int(xes[j])
            y = int(ys[j])

            wsi_img = image_patches[j]*128 + 128
            patch_mask = TissueMaskGenerationPatch(wsi_img)

            for key in models_to_save:
                prediction = pred_map_dict[key][j, :, :, 1]
                prediction *= patch_mask
                prd_im_fll_dict[key][tmp_mns(x):tmp_pls(x), tmp_mns(y):tmp_pls(y)] += prediction

            count_map[tmp_mns(x):tmp_pls(x), tmp_mns(y):tmp_pls(y)] += np.ones((patch_size, patch_size), dtype='uint8')
        if (i+1) % 100 == 0 or i < 10:
            print("Completed %i Time elapsed %.2f min | Max count %d" % (i, (time.time()-start_time)/60, count_map.max()))

    print("Fully completed %i Time elapsed %.2f min | Max count %d" % (i, (time.time()-start_time)/60, count_map.max()))
    start_time = time.time()

    print("\t Dividing by count_map")
    np.place(count_map, count_map == 0, 1)
    for key in models_to_save:
        prd_im_fll_dict[key] /= count_map
    del count_map
    gc.collect()

    print("\t Scaling prediction")
    prob_map_dict = {}
    for key in models_to_save:
        prob_map_dict[key] = scale(prd_im_fll_dict[key])
        prob_map_dict[key] = (prob_map_dict[key]*255).astype('uint8')

    print("\t Thresholding prediction")
    threshold = 0.5
    for key in models_to_save:
        np.place(prd_im_fll_dict[key], prd_im_fll_dict[key] >= threshold, 255)
        np.place(prd_im_fll_dict[key], prd_im_fll_dict[key] < threshold, 0)
    print("\t Calculated in %f" % ((time.time() - start_time)/60))
    start_time = time.time()

    print("\t Saving prediction")
    save_path = '-'.join(img_path.split('-')[:-1] + ["mask"]) + '.tiff'
    for key in models_to_save:
        print("\t Saving to %s %s" % (save_path, key))
        tifffile.imsave(save_path, prd_im_fll_dict[key].T, compress=9)
    print("\t Calculated in %f" % ((time.time() - start_time)/60))

    start_time = time.time()
    print("\t Converting prediction to a pyramidal tiff")
    os.system('convert ' + save_path + " -compress jpeg -quality 90 -define tiff:tile-geometry=256x256 ptif:" + save_path)
    print("\t Calculated in %f" % ((time.time() - start_time)/60))
    start_time = time.time()

    # print("\t Generating scaled version of ground truth")
    # scaled_prd_im_fll_dict = {}
    # for key in models_to_save:
        # scaled_prd_im_fll_dict[key] = scale(prd_im_fll_dict[key])
    # del prd_im_fll_dict
    # gc.collect()

    # mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
    # mask_im = np.dstack([TissueMaskGenerationPatch(im_im)]*3).astype('uint8')*255
    # for key in models_to_save:
        # mask_im[:,:,0] = scaled_prd_im_fll_dict[key]*255
        # ov_prob_stride = st_im + (np.dstack([prob_map_dict[key]]*3)*255).astype('uint8')
        # np.place(ov_prob_stride,ov_prob_stride>255,255)
        # imsave(mask_im,ov_prob_stride,prob_map_dict[key],scaled_prd_im_fll_dict[key],im_im,out=os.path.join(out_dir_dict[key],'ref_'+out_file)+'.png')

# for key in models_to_save:
    # with open(os.path.join(out_dir_dict[key],'jacc_scores.txt'), 'a') as f:
        # f.write("Total,%f\n" %(total_jacc_score_dict[key]/len(sample_ids)))