# DigiPathAI/new_Segmentation.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import glob
import random

import imgaug
from imgaug import augmenters as iaa
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import openslide
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, concatenate, Concatenate, UpSampling2D, Activation
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
from tensorflow.keras import metrics

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms  # noqa

import sklearn.metrics
import io
import itertools
from six.moves import range

import time
import argparse
import cv2
from skimage.color import rgb2hsv
from skimage.filters import threshold_otsu

import sys
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
from models.seg_models import get_inception_resnet_v2_unet_softmax, unet_densenet121
from models.deeplabv3p_original import Deeplabv3

# Random Seeds
np.random.seed(0)
random.seed(0)
tf.set_random_seed(0)  # TF 1.x API, consistent with tf.ConfigProto/tf.Session used below
import gc
import pandas as pd

import tifffile
import skimage.io as io  # NOTE: this shadows the stdlib `io` imported above
import DigiPathAI

# Image Helper Functions
def imsave(*args, **kwargs):
    """
    Concatenates the images given in args and saves them as a single image at
    the specified output destination.
    Images should be numpy arrays with the same size along axis 0.
    imsave(im1, im2, out="sample.png")
    """
    args_list = list(args)
    for i in range(len(args_list)):
        if type(args_list[i]) != np.ndarray:
            print("Not a numpy array")
            return 0
        if len(args_list[i].shape) == 2:
            args_list[i] = np.dstack([args_list[i]]*3)
            if args_list[i].max() == 1:
                args_list[i] = args_list[i]*255

    out_destination = kwargs.get("out", '')
    try:
        concatenated_arr = np.concatenate(args_list, axis=1)
        im = Image.fromarray(np.uint8(concatenated_arr))
    except Exception as e:
        print(e)
        return 0
    if out_destination:
        print("Saving to %s" % (out_destination))
        im.save(out_destination)
    else:
        return im

def imshow(*args, **kwargs):
    """
    Handy function to show multiple plots in one row, possibly with different
    cmaps and titles.
    Usage:
        imshow(img1, title="myPlot")
        imshow(img1, img2, title=['title1', 'title2'])
        imshow(img1, img2, cmap='hot')
        imshow(img1, img2, cmap=['gray', 'Blues'])
    """
    cmap = kwargs.get('cmap', 'gray')
    title = kwargs.get('title', '')
    axis_off = kwargs.get('axis_off', '')
    if len(args) == 0:
        raise ValueError("No images given to imshow")
    elif len(args) == 1:
        plt.title(title)
        plt.imshow(args[0], interpolation='none')
    else:
        n = len(args)
        if type(cmap) == str:
            cmap = [cmap]*n
        if type(title) == str:
            title = [title]*n
        plt.figure(figsize=(n*5, 10))
        for i in range(n):
            plt.subplot(1, n, i+1)
            plt.title(title[i])
            plt.imshow(args[i], cmap[i])
            if axis_off:
                plt.axis('off')
    plt.show()

def normalize_minmax(data):
    """
    Normalize contrast across the volume.
    """
    _min = float(np.min(data))  # np.float is deprecated in recent numpy
    _max = float(np.max(data))
    if (_max - _min) != 0:
        img = (data - _min) / (_max - _min)
    else:
        img = np.zeros_like(data)
    return img

# Functions
def BinMorphoProcessMask(mask, level):
    """
    Binary morphological operations performed on the tissue mask: closing
    fills small holes, opening removes small specks, and a final dilation
    aggregates nearby tissue blobs.
    """
    close_kernel = np.ones((20, 20), dtype=np.uint8)
    image_close = cv2.morphologyEx(np.array(mask), cv2.MORPH_CLOSE, close_kernel)
    open_kernel = np.ones((5, 5), dtype=np.uint8)
    image_open = cv2.morphologyEx(np.array(image_close), cv2.MORPH_OPEN, open_kernel)
    if level == 2:
        kernel = np.ones((60, 60), dtype=np.uint8)
    elif level == 3:
        kernel = np.ones((35, 35), dtype=np.uint8)
    else:
        raise ValueError("dilation kernel is only defined for levels 2 and 3, got %r" % level)
    image = cv2.dilate(image_open, kernel, iterations=1)
    return image
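
# Illustration (toy input): a one-pixel hole inside a tissue blob is filled by
# the closing step, an isolated speck is removed by the opening step, and the
# final dilation grows the blob outward.
# toy = np.zeros((100, 100), dtype=np.uint8); toy[20:80, 20:80] = 255; toy[50, 50] = 0
# BinMorphoProcessMask(toy, level=2)  # the hole at (50, 50) comes back filled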

def get_bbox(cont_img, rgb_image=None):
    temp_img = np.uint8(cont_img.copy())
    # cv2.findContours returns 3 values in OpenCV 3.x and 2 in OpenCV 4.x;
    # indexing [-2] picks out the contour list in both versions
    contours = cv2.findContours(temp_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    rgb_contour = None
    if rgb_image is not None:
        rgb_contour = rgb_image.copy()
        line_color = (0, 0, 255)  # blue color code
        cv2.drawContours(rgb_contour, contours, -1, line_color, 2)
    bounding_boxes = [cv2.boundingRect(c) for c in contours]
    if rgb_contour is not None:
        # cv2.boundingRect returns (x, y, w, h)
        for x, y, w, h in bounding_boxes:
            rgb_contour = cv2.rectangle(rgb_contour, (x, y), (x + w, y + h), (0, 255, 0), 2)
    return bounding_boxes, rgb_contour

def get_all_bbox_masks(mask, stride_factor):
    """
    Find the bounding boxes and the corresponding mask.
    """
    bbox_mask = np.zeros_like(mask)
    bounding_boxes, _ = get_bbox(mask)
    y_size, x_size = bbox_mask.shape
    for x, y, w, h in bounding_boxes:
        x_min = x - stride_factor
        x_max = x + w + stride_factor
        y_min = y - stride_factor
        y_max = y + h + stride_factor
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if x_max > x_size:
            x_max = x_size - 1
        if y_max > y_size:
            y_max = y_size - 1
        bbox_mask[y_min:y_max, x_min:x_max] = 1
    return bbox_mask

def get_all_bbox_masks_with_stride(mask, stride_factor):
    """
    Find the bounding boxes and the corresponding strided mask.
    """
    bbox_mask = np.zeros_like(mask)
    bounding_boxes, _ = get_bbox(mask)
    y_size, x_size = bbox_mask.shape
    for x, y, w, h in bounding_boxes:
        x_min = x - stride_factor
        x_max = x + w + stride_factor
        y_min = y - stride_factor
        y_max = y + h + stride_factor
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if x_max > x_size:
            x_max = x_size - 1
        if y_max > y_size:
            y_max = y_size - 1
        bbox_mask[y_min:y_max:stride_factor, x_min:x_max:stride_factor] = 1

    return bbox_mask

def find_largest_bbox(mask, stride_factor):
    """
    Find the largest bounding box encompassing all the blobs.
    """
    y_size, x_size = mask.shape
    x, y = np.where(mask == 1)
    bbox_mask = np.zeros_like(mask)
    x_min = np.min(x) - stride_factor
    x_max = np.max(x) + stride_factor
    y_min = np.min(y) - stride_factor
    y_max = np.max(y) + stride_factor

    if x_min < 0:
        x_min = 0

    if y_min < 0:
        y_min = 0

    if x_max > x_size:
        x_max = x_size - 1

    if y_max > y_size:
        y_max = y_size - 1

    bbox_mask[x_min:x_max, y_min:y_max] = 1
    return bbox_mask

def TissueMaskGeneration(slide_obj, level, RGB_min=50):
    img_RGB = slide_obj.read_region((0, 0), level, slide_obj.level_dimensions[level])
    img_RGB = np.transpose(np.array(img_RGB.convert('RGB')), axes=[1, 0, 2])
    img_HSV = rgb2hsv(img_RGB)
    background_R = img_RGB[:, :, 0] > threshold_otsu(img_RGB[:, :, 0])
    background_G = img_RGB[:, :, 1] > threshold_otsu(img_RGB[:, :, 1])
    background_B = img_RGB[:, :, 2] > threshold_otsu(img_RGB[:, :, 2])
    tissue_RGB = np.logical_not(background_R & background_G & background_B)
    tissue_S = img_HSV[:, :, 1] > threshold_otsu(img_HSV[:, :, 1])
    min_R = img_RGB[:, :, 0] > RGB_min
    min_G = img_RGB[:, :, 1] > RGB_min
    min_B = img_RGB[:, :, 2] > RGB_min

    tissue_mask = tissue_S & tissue_RGB & min_R & min_G & min_B
    # r = img_RGB[:,:,0] < 235
    # g = img_RGB[:,:,1] < 210
    # b = img_RGB[:,:,2] < 235
    # tissue_mask = np.logical_or(r,np.logical_or(g,b))
    return tissue_mask
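
# A pixel is kept as tissue when it is sufficiently saturated (Otsu threshold
# on the HSV S channel), not simultaneously bright in all three RGB channels
# (per-channel Otsu), and above the RGB_min floor in every channel.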

def TissueMaskGenerationPatch(patchRGB):
    '''
    Returns a mask of the tissue that obeys the thresholds set by PAIP
    '''
    r = patchRGB[:, :, 0] < 235
    g = patchRGB[:, :, 1] < 210
    b = patchRGB[:, :, 2] < 235
    tissue_mask = np.logical_or(r, np.logical_or(g, b))
    return tissue_mask
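
# Example: a near-white background pixel such as RGB (250, 240, 248) fails all
# three channel thresholds and is masked out, while stained tissue (darker in
# at least one channel) is kept.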

def TissueMaskGeneration_BIN(slide_obj, level):
    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
                           level,
                           slide_obj.level_dimensions[level]).convert('RGB')),
                           axes=[1, 0, 2])
    # the input is RGB, so COLOR_RGB2HSV is the matching conversion (the S
    # channel used below is the same under either channel ordering)
    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_RGB2HSV)
    img_S = img_HSV[:, :, 1]
    _, tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY)
    return np.array(tissue_mask)

def TissueMaskGeneration_BIN_OTSU(slide_obj, level):
    img_RGB = np.transpose(np.array(slide_obj.read_region((0, 0),
                           level,
                           slide_obj.level_dimensions[level]).convert('RGB')),
                           axes=[1, 0, 2])
    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_RGB2HSV)
    img_S = img_HSV[:, :, 1]
    _, tissue_mask = cv2.threshold(img_S, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return np.array(tissue_mask)

def labelthreshold(image, threshold=0.5):
    np.place(image, image >= threshold, 1)
    np.place(image, image < threshold, 0)
    return np.uint8(image)
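
# Note: labelthreshold modifies `image` in place via np.place and also returns
# the thresholded array as uint8, e.g.
# labelthreshold(np.array([0.2, 0.7]))  ->  array([0, 1], dtype=uint8)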

def calc_jacc_score(x, y, smoothing=1):
    for var in [x, y]:
        np.place(var, var == 255, 1)

    numerator = np.sum(x*y)
    denominator = np.sum(np.logical_or(x, y))
    return (numerator + smoothing)/(denominator + smoothing)
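
# Worked example (illustrative): for x = [[1, 0], [1, 1]] and y = [[1, 1], [0, 1]],
# sum(x*y) = 2 (intersection) and sum(x | y) = 4 (union), so with smoothing=1
# the score is (2 + 1) / (4 + 1) = 0.6.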

# DataLoader Implementation
class WSIStridedPatchDataset(Dataset):
    """
    Data producer that generates all the square grids, e.g. 3x3, of patches
    from a WSI and its tissue mask, together with their corresponding indices
    with respect to the tissue mask.
    """
    def __init__(self, wsi_path, mask_path, label_path=None, image_size=256,
                 normalize=True, flip='NONE', rotate='NONE',
                 level=5, sampling_stride=16, roi_masking=True):
        """
        Initialize the data producer.

        Arguments:
            wsi_path: string, path to the WSI file
            mask_path: string, path to the mask file in numpy format, or None
            label_path: string, path to the ground-truth label mask in tif
                format, or None (in case of a normal WSI, or at test time)
            image_size: int, size of the image before splitting into a grid,
                e.g. 768
            normalize: bool, whether to normalize the [0, 255] pixel values to
                [-1, 1]; mostly False for debugging purposes
            flip: string, 'NONE' or 'FLIP_LEFT_RIGHT' indicating the flip type
            rotate: string, 'NONE' or 'ROTATE_90' or 'ROTATE_180' or
                'ROTATE_270', indicating the rotation type
            level: level at which the WSI tissue mask is extracted
            roi_masking: True: multiply the strided WSI with the tissue mask to
                    eliminate white spaces;
                False: run inference over the entire WSI
            sampling_stride: number of pixels to skip in the tissue mask;
                effectively the overlap fraction when patches are extracted
                from the WSI during inference.
                stride=1 -> consecutive pixels are utilized
                stride=image_size/pow(2, level) -> non-overlapping patches
        """
        self._wsi_path = wsi_path
        self._mask_path = mask_path
        self._label_path = label_path
        self._image_size = image_size
        self._normalize = normalize
        self._flip = flip
        self._rotate = rotate
        self._level = level
        self._sampling_stride = sampling_stride
        self._roi_masking = roi_masking

        self._preprocess()

    def _preprocess(self):
        self._slide = openslide.OpenSlide(self._wsi_path)

        if self._label_path is not None:
            self._label_slide = openslide.OpenSlide(self._label_path)

        X_slide, Y_slide = self._slide.level_dimensions[0]
        print("Image dimensions: (%d,%d)" % (X_slide, Y_slide))

        factor = self._sampling_stride

        if self._mask_path is not None:
            mask_file_name = os.path.basename(self._mask_path)
            if mask_file_name.endswith('.tiff'):
                mask_obj = openslide.OpenSlide(self._mask_path)
                self._mask = np.array(mask_obj.read_region((0, 0),
                                      self._level,
                                      mask_obj.level_dimensions[self._level]).convert('L')).T
                np.place(self._mask, self._mask > 0, 255)
        else:
            # Generate the tissue mask on the fly
            self._mask = TissueMaskGeneration(self._slide, self._level)
            # morphological operations ensure the holes in the tissue mask are
            # filled and minor points are aggregated to form a larger chunk
            self._mask = BinMorphoProcessMask(np.uint8(self._mask), self._level)
            # self._all_bbox_mask = get_all_bbox_masks(self._mask, factor)
            # self._largest_bbox_mask = find_largest_bbox(self._mask, factor)
            # self._all_strided_bbox_mask = get_all_bbox_masks_with_stride(self._mask, factor)

        X_mask, Y_mask = self._mask.shape
        # print(self._mask.shape, np.where(self._mask > 0))
        # imshow(self._mask.T)
        # the cm17 dataset had issues with image sizes being exact powers of 2
        # if X_slide != X_mask or Y_slide != Y_mask:
        print('Mask (%d,%d) and Slide (%d,%d)' % (X_mask, Y_mask, X_slide, Y_slide))
        if X_slide // X_mask != Y_slide // Y_mask:
            raise Exception('Slide/Mask dimensions do not match,'
                            ' X_slide / X_mask : {} / {},'
                            ' Y_slide / Y_mask : {} / {}'
                            .format(X_slide, X_mask, Y_slide, Y_mask))

        self._resolution = np.round(X_slide * 1.0 / X_mask)
        if not np.log2(self._resolution).is_integer():
            raise Exception('Resolution (X_slide / X_mask) is not a power of 2:'
                            ' {}'.format(self._resolution))

        # all the indices for the tissue region from the tissue mask
        self._strided_mask = np.ones_like(self._mask)
        ones_mask = np.zeros_like(self._mask)
        ones_mask[::factor, ::factor] = self._strided_mask[::factor, ::factor]

        if self._roi_masking:
            self._strided_mask = ones_mask*self._mask
            # self._strided_mask = ones_mask*self._largest_bbox_mask
            # self._strided_mask = ones_mask*self._all_bbox_mask
            # self._strided_mask = self._all_strided_bbox_mask
        else:
            self._strided_mask = ones_mask
        # print(np.count_nonzero(self._strided_mask), np.count_nonzero(self._mask[::factor, ::factor]))
        # imshow(self._strided_mask.T, self._mask[::factor, ::factor].T)
        # imshow(self._mask.T, self._strided_mask.T)

        self._X_idcs, self._Y_idcs = np.where(self._strided_mask)
        self._idcs_num = len(self._X_idcs)
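
        # Each nonzero pixel retained in the strided mask seeds one inference
        # patch: __getitem__ scales its (x, y) index by self._resolution back to
        # level-0 coordinates and reads an image_size x image_size crop there.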

    def __len__(self):
        return self._idcs_num

    def save_scaled_imgs(self):
        scld_dms = self._slide.level_dimensions[self._level]
        self._slide_scaled = self._slide.read_region((0, 0), self._level, scld_dms)

        if self._label_path is not None:
            # the read level was previously hardcoded to 4; self._level is
            # assumed to be the intent, since scld_dms is derived from it
            self._label_scaled = np.array(self._label_slide.read_region((0, 0), self._level, scld_dms).convert('L'))
            np.place(self._label_scaled, self._label_scaled > 0, 255)

    def save_get_mask(self, save_path):
        np.save(save_path, self._mask)

    def get_mask(self):
        return self._mask

    def get_strided_mask(self):
        return self._strided_mask

    def __getitem__(self, idx):
        x_coord, y_coord = self._X_idcs[idx], self._Y_idcs[idx]

        x_max_dim, y_max_dim = self._slide.level_dimensions[0]

        # centre the patch on the sampled tissue-mask pixel, scaled to level 0
        x = int(x_coord * self._resolution - self._image_size//2)
        y = int(y_coord * self._resolution - self._image_size//2)

        # clamp patches that would go out of bounds
        if x > (x_max_dim - self._image_size):
            x = x_max_dim - self._image_size
        elif x < 0:
            x = 0
        if y > (y_max_dim - self._image_size):
            y = y_max_dim - self._image_size
        elif y < 0:
            y = 0

        img = self._slide.read_region(
            (x, y), 0, (self._image_size, self._image_size)).convert('RGB')

        if self._label_path is not None:
            label_img = self._label_slide.read_region(
                (x, y), 0, (self._image_size, self._image_size)).convert('L')
        else:
            label_img = Image.fromarray(np.zeros((self._image_size, self._image_size), dtype=np.uint8))

        # apply the transposes while img/label_img are still PIL images
        if self._flip == 'FLIP_LEFT_RIGHT':
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            label_img = label_img.transpose(Image.FLIP_LEFT_RIGHT)

        if self._rotate == 'ROTATE_90':
            img = img.transpose(Image.ROTATE_90)
            label_img = label_img.transpose(Image.ROTATE_90)

        if self._rotate == 'ROTATE_180':
            img = img.transpose(Image.ROTATE_180)
            label_img = label_img.transpose(Image.ROTATE_180)

        if self._rotate == 'ROTATE_270':
            img = img.transpose(Image.ROTATE_270)
            label_img = label_img.transpose(Image.ROTATE_270)

        # converting a PIL image to a numpy array transposes W and H
        img = np.transpose(np.array(img, dtype=np.float32), [1, 0, 2])
        label_img = np.array(label_img, dtype=np.uint8)
        np.place(label_img, label_img > 0, 255)

        if self._normalize:
            img = (img - 128.0)/128.0

        return (img, x, y, label_img)
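
# Minimal usage sketch (hypothetical path): stream full-resolution tissue
# patches of a WSI, seeded from a level-5 tissue mask.
# dataset = WSIStridedPatchDataset('slide.tiff', mask_path=None, label_path=None,
#                                  image_size=256, normalize=True,
#                                  flip='NONE', rotate='NONE',
#                                  level=5, sampling_stride=8, roi_masking=True)
# loader = DataLoader(dataset, batch_size=4)
# for img, x, y, label in loader:
#     pass  # img: (4, 256, 256, 3) float32 in [-1, 1]; (x, y): level-0 corners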

def getSegmentation(img_path,
                    patch_size=256,
                    stride_size=128,
                    batch_size=32,
                    quick=True,
                    tta_list=None,
                    crf=False,
                    status=None):
    """
    Saves the prediction at the same location as the input image.
    args:
        img_path: WSI tiff image path (str)
        patch_size: patch size for inference (int)
        stride_size: stride to skip during segmentation (int)
        batch_size: batch size during inference (int)
        quick: if True, the prediction comes from a single model (bool);
            otherwise the final segmentation is an ensemble of several models
        tta_list: type of augmentation required during inference;
            allowed: ['FLIP_LEFT_RIGHT', 'ROTATE_90', 'ROTATE_180', 'ROTATE_270'] (list(str))
        crf: whether to apply conditional random fields as a post-processing step (bool)
        status: required for the webserver (json)

    return:
        None; the prediction mask is saved alongside the input image
        (in .tiff format)
    """
    # NOTE: tta_list, crf and status are accepted but not used in this version.

    # Model loading
    core_config = tf.ConfigProto()
    core_config.gpu_options.allow_growth = False
    session = tf.Session(config=core_config)
    K.set_session(session)

    def load_incep_resnet(model_path):
        model = get_inception_resnet_v2_unet_softmax((None, None), weights=None)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    def load_unet_densenet(model_path):
        model = unet_densenet121((None, None), weights=None)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    def load_deeplabv3(model_path, OS):
        model = Deeplabv3(input_shape=(patch_size, patch_size, 3), weights=None, classes=2, activation='softmax', backbone='xception', OS=OS)
        model.load_weights(model_path)
        print("Loaded Model Weights %s" % model_path)
        return model

    model_path_root = os.path.join(DigiPathAI.digipathai_folder, 'digestpath_models')
    model_dict = {}
    if quick:
        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root, 'densenet.h5'))
    else:
        model_dict['inception'] = load_incep_resnet(os.path.join(model_path_root, 'inception.h5'))
        model_dict['densenet'] = load_unet_densenet(os.path.join(model_path_root, 'densenet.h5'))
        # the original call omitted the required output-stride argument;
        # OS=16 is assumed here
        model_dict['deeplab'] = load_deeplabv3(os.path.join(model_path_root, 'deeplab.h5'), OS=16)

    # record the real models before adding the ensemble placeholder, so that
    # .predict() is never called on the placeholder string below
    model_keys = list(model_dict.keys())
    ensemble_key = 'ensemble_key'
    model_dict[ensemble_key] = 'ensemble'
    models_to_save = [ensemble_key]

    # Stitcher
    start_time = time.time()
    wsi_path = img_path
    wsi_obj = openslide.OpenSlide(wsi_path)
    x_max_dim, y_max_dim = wsi_obj.level_dimensions[0]
    count_map = np.zeros(wsi_obj.level_dimensions[0], dtype='uint8')

    prd_im_fll_dict = {}
    memmaps_path = os.path.join(DigiPathAI.digipathai_folder, 'memmaps')
    os.makedirs(memmaps_path, exist_ok=True)
    for key in models_to_save:
        prd_im_fll_dict[key] = np.memmap(os.path.join(memmaps_path, '%s.dat' % (key)), dtype=np.float32, mode='w+', shape=wsi_obj.level_dimensions[0])

    # Take the smallest resolution available
    level = len(wsi_obj.level_dimensions) - 1
    scld_dms = wsi_obj.level_dimensions[-1]
    scale_sampling_stride = stride_size//int(wsi_obj.level_downsamples[level])
    print("Level %d , stride %d, scale stride %d" % (level, stride_size, scale_sampling_stride))

    scale = lambda x: cv2.resize(x, tuple(reversed(scld_dms))).T
    mask_path = None
    start_time = time.time()
    dataset_obj = WSIStridedPatchDataset(wsi_path,
                                         mask_path=None,
                                         label_path=None,
                                         image_size=patch_size,
                                         normalize=True,
                                         flip='NONE', rotate='NONE',
                                         level=level, sampling_stride=scale_sampling_stride, roi_masking=True)

    # NOTE: drop_last=True skips the final partial batch, so a few edge
    # patches may go unpredicted
    dataloader = DataLoader(dataset_obj, batch_size=batch_size, num_workers=batch_size, drop_last=True)
    dataset_obj.save_scaled_imgs()

    print(dataset_obj.get_mask().shape)
    mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
    st_im = np.dstack([dataset_obj.get_strided_mask().T]*3).astype('uint8')*255
    im_im = np.array(dataset_obj._slide_scaled.convert('RGB'))
    ov_im = mask_im/2 + im_im/2
    ov_im_stride = st_im/2 + im_im/2

    print("Total iterations: %d %d" % (dataloader.__len__(), dataloader.dataset.__len__()))
    for i, (data, xes, ys, label) in enumerate(dataloader):
        tmp_pls = lambda x: x + patch_size  # slice helpers: [x, x + patch_size)
        tmp_mns = lambda x: x
        image_patches = data.cpu().data.numpy()

        pred_map_dict = {}
        pred_map_dict[ensemble_key] = 0
        for key in model_keys:
            pred_map_dict[key] = model_dict[key].predict(image_patches, verbose=0, batch_size=batch_size)
            pred_map_dict[ensemble_key] += pred_map_dict[key]
        pred_map_dict[ensemble_key] /= len(model_keys)

        actual_batch_size = image_patches.shape[0]
        for j in range(actual_batch_size):
            x = int(xes[j])
            y = int(ys[j])

            # undo the (-1, 1) normalization to recover the RGB patch
            wsi_img = image_patches[j]*128 + 128
            patch_mask = TissueMaskGenerationPatch(wsi_img)

            for key in models_to_save:
                # channel 1 of the two-class softmax output is the
                # foreground probability
                prediction = pred_map_dict[key][j, :, :, 1]
                prediction *= patch_mask
                prd_im_fll_dict[key][tmp_mns(x):tmp_pls(x), tmp_mns(y):tmp_pls(y)] += prediction

            count_map[tmp_mns(x):tmp_pls(x), tmp_mns(y):tmp_pls(y)] += np.ones((patch_size, patch_size), dtype='uint8')
        if (i + 1) % 100 == 0 or i < 10:
            print("Completed %i Time elapsed %.2f min | Max count %d " % (i, (time.time()-start_time)/60, count_map.max()))

    print("Fully completed %i Time elapsed %.2f min | Max count %d " % (i, (time.time()-start_time)/60, count_map.max()))
    start_time = time.time()

    print("\t Dividing by count_map")
    np.place(count_map, count_map == 0, 1)
    for key in models_to_save:
        prd_im_fll_dict[key] /= count_map
    del count_map
    gc.collect()
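    # With the defaults (patch_size=256, stride_size=128), interior pixels are
    # covered by up to four overlapping patches; count_map records that
    # multiplicity, so the division above turns the accumulated probabilities
    # back into a per-pixel mean over overlapping predictions.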

    print("\t Scaling prediction")
    prob_map_dict = {}
    for key in models_to_save:
        prob_map_dict[key] = scale(prd_im_fll_dict[key])
        prob_map_dict[key] = (prob_map_dict[key]*255).astype('uint8')

    print("\t Thresholding prediction")
    threshold = 0.5
    for key in models_to_save:
        np.place(prd_im_fll_dict[key], prd_im_fll_dict[key] >= threshold, 255)
        np.place(prd_im_fll_dict[key], prd_im_fll_dict[key] < threshold, 0)
    print("\t Calculated in %.2f min" % ((time.time() - start_time)/60))
    start_time = time.time()

    print("\t Saving prediction")
    save_model_keys = models_to_save
    save_path = '-'.join(img_path.split('-')[:-1] + ["mask"]) + '.tiff'
    for key in models_to_save:
        print("\t Saving to %s %s" % (save_path, key))
        tifffile.imsave(save_path, np.uint8(prd_im_fll_dict[key].T), compress=9)
    print("\t Calculated in %.2f min" % ((time.time() - start_time)/60))
    start_time = time.time()

    print("\t Converting to pyramidal tiff")
    os.system('convert ' + save_path + " -compress jpeg -quality 90 -define tiff:tile-geometry=256x256 ptif:" + save_path)
    print("\t Calculated in %.2f min" % ((time.time() - start_time)/60))
    start_time = time.time()

    # print("\t Generating scaled version of ground truth")
    # scaled_prd_im_fll_dict = {}
    # for key in models_to_save:
    #     scaled_prd_im_fll_dict[key] = scale(prd_im_fll_dict[key])
    # del prd_im_fll_dict
    # gc.collect()

    # mask_im = np.dstack([dataset_obj.get_mask().T]*3).astype('uint8')*255
    # mask_im = np.dstack([TissueMaskGenerationPatch(im_im)]*3).astype('uint8')*255
    # for key in models_to_save:
    #     mask_im[:,:,0] = scaled_prd_im_fll_dict[key]*255
    #     ov_prob_stride = st_im + (np.dstack([prob_map_dict[key]]*3)*255).astype('uint8')
    #     np.place(ov_prob_stride,ov_prob_stride>255,255)
    #     imsave(mask_im,ov_prob_stride,prob_map_dict[key],scaled_prd_im_fll_dict[key],im_im,out=os.path.join(out_dir_dict[key],'ref_'+out_file)+'.png')

    # for key in models_to_save:
    #     with open(os.path.join(out_dir_dict[key],'jacc_scores.txt'), 'a') as f:
    #         f.write("Total,%f\n" %(total_jacc_score_dict[key]/len(sample_ids)))
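
# Example call (hypothetical path): single-model (quick) segmentation of a WSI;
# the mask is written next to the input as '<prefix>-mask.tiff'.
# getSegmentation('/data/slides/sample-001.tiff', patch_size=256,
#                 stride_size=128, batch_size=32, quick=True)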