|
a |
|
b/extract_features.py |
|
|
1 |
import argparse |
|
|
2 |
import os |
|
|
3 |
import time |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
import openslide |
|
|
7 |
import cv2 |
|
|
8 |
from PIL import Image, ImageDraw |
|
|
9 |
from shapely.affinity import scale |
|
|
10 |
from shapely.geometry import Polygon, MultiPolygon |
|
|
11 |
from shapely.ops import unary_union |
|
|
12 |
from collections import defaultdict |
|
|
13 |
|
|
|
14 |
import nmslib |
|
|
15 |
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity |
|
|
16 |
|
|
|
17 |
# Optional if stain deconvolution is used. |
|
|
18 |
import histomicstk as htk #pip install histomicstk --find-links https://girder.github.io/large_image_wheels |
|
|
19 |
|
|
|
20 |
import torch |
|
|
21 |
from torch.utils.data import DataLoader, Dataset |
|
|
22 |
from torchvision import transforms |
|
|
23 |
from esvit.utils import bool_flag |
|
|
24 |
|
|
|
25 |
# You can use your own encoder to extract features. Here are examples, including EsVIT, the one used in the publication. |
|
|
26 |
from encoders import load_encoder_esVIT, load_encoder_resnet |
|
|
27 |
|
|
|
28 |
def get_args_parser():
    """Build the command-line parser for the feature-extraction script.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``ArgumentParser`` (which is how the
    ``__main__`` section of this file uses it).

    Returns:
        argparse.ArgumentParser: parser holding the slide/output options, the
        tiling and merging options, and the EsVIT encoder options.
    """
    parser = argparse.ArgumentParser('Preprocessing script esvit', add_help=False)

    # --- Input / output ---------------------------------------------------
    parser.add_argument(
        "--input_slide",
        type=str,
        help="Path to input WSI file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save output data",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Feature extractor weights checkpoint",
    )

    # --- Tiling and data loading -------------------------------------------
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--tile_size",
        help="Desired tile size in microns (should be the same value as used in feature extraction model).",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--out_size",
        help="Resize the square tile to this output size (in pixels).",
        type=int,
        default=224,
    )
    parser.add_argument(
        "--method",
        # The segmentation code only recognises the literal 'stain_deconv';
        # every other value takes the Otsu branch.
        help="Segmentation method: 'otsu' or 'stain_deconv'. Any other value falls back to Otsu thresholding.",
        type=str,
        default='otsu',
    )

    # --- Patch merging -----------------------------------------------------
    parser.add_argument(
        "--dist_threshold",
        type=int,
        default=4,
        help="Maximum L2 distance, expressed in multiples of the patch side length, when spatially merging patches.",
    )
    parser.add_argument(
        "--corr_threshold",
        type=float,
        default=0.6,
        help="Minimum cosine similarity when semantically merging patches.",
    )
    parser.add_argument(
        "--workers",
        help="The number of workers to use for the data loader. Only relevant when using a GPU.",
        type=int,
        default=4,
    )

    # --- EsVIT encoder options (see the EsVIT repo) ------------------------
    parser.add_argument(
        '--cfg',
        help='experiment configure file name. See EsVIT repo.',
        type=str,
    )
    parser.add_argument(
        '--arch',
        default='deit_small',
        type=str,
        choices=[
            'cvt_tiny', 'swin_tiny', 'swin_small', 'swin_base', 'swin_large',
            'swin', 'vil', 'vil_1281', 'vil_2262', 'deit_tiny', 'deit_small',
            'vit_base',
        ],
        help="Name of architecture to train. For quick experiments with ViTs, "
             "we recommend using deit_tiny or deit_small. See EsVIT repo.",
    )
    parser.add_argument(
        '--n_last_blocks',
        default=4,
        type=int,
        help="Concatenate [CLS] tokens for the `n` last blocks. We use `n=4` when "
             "evaluating DeiT-Small and `n=1` with ViT-Base. See EsVIT repo.",
    )
    parser.add_argument(
        '--avgpool_patchtokens',
        default=False,
        type=bool_flag,
        help="Whether or not to concatenate the global average pooled features to the [CLS] token. "
             "We typically set this to False for DeiT-Small and to True with ViT-Base. See EsVIT repo.",
    )
    parser.add_argument(
        '--patch_size',
        default=8,
        type=int,
        help='Patch resolution of the model. See EsVIT repo.',
    )
    parser.add_argument(
        'opts',
        help="Modify config options using the command-line. See EsVIT repo.",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--rank",
        default=0,
        type=int,
        help="Please ignore and do not set this argument.",
    )

    return parser
|
|
128 |
|
|
|
129 |
def segment_tissue(img):
    """Find tissue contours by Otsu-thresholding the HSV saturation channel.

    The image is converted from RGB to HSV, channel 1 is median-blurred,
    binarised with Otsu's threshold and morphologically closed before the
    contours are traced.

    Args:
        img: image array accepted by ``cv2.cvtColor(..., cv2.COLOR_RGB2HSV)``.

    Returns:
        The ``(contours, hierarchy)`` pair produced by ``cv2.findContours``
        with ``cv2.RETR_CCOMP`` and ``cv2.CHAIN_APPROX_NONE``.
    """
    median_kernel_size = 7
    closing_size = 4

    saturation = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel_size)
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)

    # Closing fills small gaps in the binary mask before tracing contours.
    structuring_element = np.ones((closing_size, closing_size), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, structuring_element)

    contours, hierarchy = cv2.findContours(binary, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    return contours, hierarchy
|
|
144 |
|
|
|
145 |
def segment_tissue_deconv_stain(img):
    """
    Method 2: tissue segmentation using stain deconvolution (alternative to
    Otsu thresholding on saturation).

    The stain matrix is estimated with HistomicsTK's Macenko PCA routine, the
    image is deconvolved, and the hematoxylin and eosin channels are each
    Otsu-thresholded and OR-ed together. The combined mask is then cleaned
    with five rounds of 3x3 opening followed by closing.

    Args:
        img: image array whose last axis has four entries (pixels whose last
            channel is 0 are overwritten with ``[255, 255, 255, 0]``).
            NOTE(review): presumably RGBA as returned by OpenSlide - confirm.

    Returns:
        ``numpy.uint8`` mask with the same height/width as the input.
    """
    rgba = img.copy()

    # Pixels whose last channel is zero are painted white before the
    # conversion to RGB.
    rgba[rgba[..., -1] == 0] = [255, 255, 255, 0]
    rgb = np.asarray(Image.fromarray(rgba).convert('RGB'))

    background_intensity = 255
    deconv = htk.preprocessing.color_deconvolution
    stain_lookup = deconv.stain_color_map

    stain_matrix = deconv.rgb_separate_stains_macenko_pca(rgb, background_intensity)
    separated = deconv.color_deconvolution(rgb, stain_matrix, background_intensity)

    combined_mask = np.zeros(rgb.shape[0:2], np.uint8)

    # hematoxylin: nuclei stain, eosin: cytoplasm stain
    for stain_name in ('hematoxylin', 'eosin'):
        channel = deconv.find_stain_index(stain_lookup[stain_name], stain_matrix)
        inverted = 255 - separated.Stains[:, :, channel]
        _, stain_mask = cv2.threshold(
            inverted, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
        combined_mask = cv2.bitwise_or(combined_mask, stain_mask)

    structuring_element = np.ones((3, 3), np.uint8)
    for _ in range(5):
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, structuring_element)
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, structuring_element)

    return combined_mask
|
|
187 |
|
|
|
188 |
def mask_to_polygons(mask, min_area, min_area_holes=10., epsilon=10.):
    """Convert a mask ndarray (binarized image) to a MultiPolygon.

    Contours are traced with OpenCV (``RETR_CCOMP``), top-level contours of
    at least ``min_area`` become polygon shells, and their child contours of
    at least ``min_area_holes`` become holes.

    Args:
        mask: binary image passed to ``cv2.findContours``.
        min_area: minimum contour area for a shell to be kept.
        min_area_holes: minimum contour area for a hole to be kept.
        epsilon: unused; kept for backward compatibility of the signature.

    Returns:
        shapely.geometry.MultiPolygon: the tissue polygons; an empty
        MultiPolygon if OpenCV finds no contours at all.

    Raises:
        Exception: if contours were found but none survived the filtering.
    """
    # first, find contours with cv2: it's much faster than shapely
    contours, hierarchy = cv2.findContours(mask,
                                           cv2.RETR_CCOMP,
                                           cv2.CHAIN_APPROX_NONE)
    if not contours:
        return MultiPolygon()

    cnt_children = defaultdict(list)
    child_contours = set()
    assert hierarchy.shape[0] == 1

    # http://docs.opencv.org/3.1.0/d9/d8b/tutorial_py_contours_hierarchy.html
    for idx, (_, _, _, parent_idx) in enumerate(hierarchy[0]):
        if parent_idx != -1:
            child_contours.add(idx)
            cnt_children[parent_idx].append(contours[idx])

    # create actual polygons filtering by area (removes artifacts)
    all_polygons = []
    for idx, cnt in enumerate(contours):
        if idx in child_contours or cv2.contourArea(cnt) < min_area:
            continue

        assert cnt.shape[1] == 1
        poly = Polygon(
            shell=cnt[:, 0, :],
            holes=[c[:, 0, :] for c in cnt_children.get(idx, [])
                   if cv2.contourArea(c) >= min_area_holes])

        if not poly.is_valid:
            # This is likely because the polygon is self-touching or self-crossing.
            # Try and 'correct' the polygon using the zero-length buffer() trick.
            # See https://shapely.readthedocs.io/en/stable/manual.html#object.buffer
            poly = poly.buffer(0)

        # buffer(0) may split a self-crossing shape into several polygons (or
        # yield an empty geometry). MultiPolygon() only accepts plain
        # polygons, so flatten the result and drop empty pieces.
        parts = list(poly.geoms) if isinstance(poly, MultiPolygon) else [poly]
        all_polygons.extend(part for part in parts if not part.is_empty)

    if not all_polygons:
        raise Exception("Raw tissue mask consists of 0 polygons")

    return MultiPolygon(all_polygons)
|
|
234 |
|
|
|
235 |
def detect_foreground(contours, hierarchy):
    """Split OpenCV contours into top-level contours and their holes.

    Args:
        contours: sequence of contours as returned by ``cv2.findContours``.
        hierarchy: matching hierarchy array of shape ``(1, N, 4)`` whose last
            column holds the parent index (``-1`` for top-level contours), or
            ``None`` when no contour was found.

    Returns:
        tuple: ``(foreground_contours, hole_contours)`` where
        ``foreground_contours`` lists the contours without a parent and
        ``hole_contours[i]`` lists the contours whose parent is
        ``foreground_contours[i]``.
    """
    # Without this guard, np.squeeze() fails with an unhelpful axis error
    # when no contour was found and hierarchy is None.
    if hierarchy is None:
        return [], []

    parent_of = np.squeeze(hierarchy, axis=(0,))[:, 3]

    # find foreground contours (parent == -1)
    top_level_ids = np.flatnonzero(parent_of == -1)
    foreground_contours = [contours[idx] for idx in top_level_ids]

    # for every foreground contour, collect the contours that name it as parent
    hole_contours = [
        [contours[idx] for idx in np.flatnonzero(parent_of == top_idx)]
        for top_idx in top_level_ids
    ]

    return foreground_contours, hole_contours
|
|
252 |
|
|
|
253 |
def construct_polygon(foreground_contours, hole_contours, min_area):
    """Build the tissue geometry from foreground contours and their holes.

    Args:
        foreground_contours: outer contours, one per tissue region.
        hole_contours: for each foreground contour, the list of contours that
            should be cut out of it.
        min_area: regions and holes with a smaller area are ignored.

    Returns:
        The ``unary_union`` of all retained regions (a Polygon or
        MultiPolygon for most tissue masks).

    Raises:
        Exception: if no region survives the filtering.
    """
    tissue_regions = []

    for outline, inner_contours in zip(foreground_contours, hole_contours):
        # The Polygon constructor cannot work with fewer than 3 points.
        if len(outline) < 3:
            continue

        # Drop the redundant contour dimension before handing it to Shapely.
        region = Polygon(np.squeeze(outline))

        # Too small to be considered tissue.
        if region.area < min_area:
            continue

        if not region.is_valid:
            # Likely self-touching or self-crossing; repair it with the
            # zero-length buffer() trick.
            # See https://shapely.readthedocs.io/en/stable/manual.html#object.buffer
            region = region.buffer(0)

        # Punch the holes in the region.
        for inner in inner_contours:
            if len(inner) < 3:
                continue

            cavity = Polygon(np.squeeze(inner))

            # Invalid holes and very small holes are left filled.
            if cavity.is_valid and not cavity.area < min_area:
                region = region.difference(cavity)

        tissue_regions.append(region)

    if len(tissue_regions) == 0:
        raise Exception("Raw tissue mask consists of 0 polygons")

    # Overlap between several regions is merged by unary_union().
    return unary_union(tissue_regions)
|
|
297 |
|
|
|
298 |
def generate_tiles(tile_width_pix, tile_height_pix, img_width, img_height, offsets=((0, 0),)):
    """Generate a grid of rectangular tiles covering the image.

    For every ``(x, y)`` offset a full grid is generated, which gives a
    stride-like overlap effect when several offsets are supplied. One extra
    tile size is added to the range stop so that tiles are not cut off at the
    edges.

    Args:
        tile_width_pix: tile width; used as ``range`` step, so it must be an int.
        tile_height_pix: tile height; used as ``range`` step, so it must be an int.
        img_width: width up to which tiles are generated.
        img_height: height up to which tiles are generated.
        offsets: iterable of ``(xmin, ymin)`` grid origins. The default is an
            immutable tuple so it cannot be shared-and-mutated across calls.

    Returns:
        list: shapely ``Polygon`` rectangles. Each rectangle spans
        ``[x, x + tile_width_pix]`` horizontally and
        ``[y - tile_height_pix, y]`` vertically for every grid point ``(x, y)``.
    """
    range_stop_width = int(np.ceil(img_width + tile_width_pix))
    range_stop_height = int(np.ceil(img_height + tile_height_pix))

    rects = []
    for xmin, ymin in offsets:
        cols = range(int(np.floor(xmin)), range_stop_width, tile_width_pix)
        rows = range(int(np.floor(ymin)), range_stop_height, tile_height_pix)
        for x in cols:
            for y in rows:
                rects.append(
                    Polygon(
                        [
                            (x, y),
                            (x + tile_width_pix, y),
                            (x + tile_width_pix, y - tile_height_pix),
                            (x, y - tile_height_pix),
                        ]
                    )
                )
    return rects
|
|
321 |
|
|
|
322 |
def make_tile_QC_fig(tiles, slide, level, line_width_pix=1, extra_tiles=None):
    """Draw tile outlines on an image of the slide at the given level.

    Args:
        tiles: geometries with a ``bounds`` attribute in level-0 coordinates;
            outlined in light green.
        slide: slide object providing ``read_region``, ``level_dimensions``
            and ``level_downsamples``.
        level: slide level used for rendering.
        line_width_pix: outline width for ``tiles``.
        extra_tiles: optional additional geometries (e.g. excluded or sampled
            tiles); outlined in blue, one pixel wider than ``tiles``.

    Returns:
        The image read from the slide with the outlines drawn on it.
    """
    canvas = slide.read_region((0, 0), level, slide.level_dimensions[level])
    to_level_scale = 1 / slide.level_downsamples[level]
    pen = ImageDraw.Draw(canvas, "RGBA")

    def draw_outlines(rects, colour, width):
        # Bounds are scaled down to the rendered level before drawing.
        for rect in rects:
            box = tuple(np.array(rect.bounds) * to_level_scale)
            pen.rectangle(box, outline=colour, width=width)

    draw_outlines(tiles, "lightgreen", line_width_pix)

    # allow to display other tiles, such as excluded or sampled
    if extra_tiles:
        draw_outlines(extra_tiles, "blue", line_width_pix + 1)

    return canvas
|
|
339 |
|
|
|
340 |
def create_tissue_mask(wsi, seg_level, method='otsu'):
    """Segment the tissue on one slide level and scale it to level 0.

    Args:
        wsi: slide object providing ``read_region``, ``level_dimensions`` and
            ``level_downsamples``.
        seg_level: slide level on which the segmentation is computed.
        method: ``'stain_deconv'`` selects stain deconvolution; any other
            value uses Otsu thresholding.

    Returns:
        The tissue geometry scaled by the level's downsample factor around
        the origin ``(0, 0)``.
    """
    dims = wsi.level_dimensions[seg_level]
    level_img = np.array(wsi.read_region((0, 0), seg_level, dims))

    # Minimum surface area of tissue polygons (in pixels of the chosen level):
    # 1/500 of the level's total area. Note that this value should be sensible
    # in the context of the chosen tile size.
    min_area = (dims[0] * dims[1]) / 500

    if method == 'stain_deconv':
        tissue = mask_to_polygons(segment_tissue_deconv_stain(level_img), min_area)
    else:
        contours, hierarchy = segment_tissue(level_img)
        outer_contours, holes = detect_foreground(contours, hierarchy)
        tissue = construct_polygon(outer_contours, holes, min_area)

    # Bring the geometry into the coordinate space of the slide's level 0.
    factor = wsi.level_downsamples[seg_level]
    return scale(tissue, xfact=factor, yfact=factor, zfact=1.0, origin=(0, 0))
|
|
368 |
|
|
|
369 |
def create_tissue_tiles(wsi, tissue_mask_scaled, tile_size_microns, offsets_micron=None):
    """Create square tiles of a given physical size that touch the tissue mask.

    Args:
        wsi: slide object whose ``properties`` contain the microns-per-pixel
            values for X and Y.
        tissue_mask_scaled: tissue geometry in level-0 coordinates.
        tile_size_microns: desired tile side length in microns.
        offsets_micron: optional non-empty sequence of grid shifts in microns;
            each shift is applied to both axes.

    Returns:
        list: tiles (shapely polygons) that intersect ``tissue_mask_scaled``.

    Raises:
        AssertionError: if the microns-per-pixel properties are missing or
            ``offsets_micron`` is empty.
    """
    print(f"tile size is {tile_size_microns} um")

    # The tile size in pixels follows from the tile size in microns and the
    # image resolution, so both resolution properties must be present.
    slide_props = wsi.properties
    assert (
        openslide.PROPERTY_NAME_MPP_X in slide_props
    ), "microns per pixel along X-dimension not available"
    assert (
        openslide.PROPERTY_NAME_MPP_Y in slide_props
    ), "microns per pixel along Y-dimension not available"

    mpp_x = float(slide_props[openslide.PROPERTY_NAME_MPP_X])
    mpp_y = float(slide_props[openslide.PROPERTY_NAME_MPP_Y])

    # For larger tiles in micron, NKI scanner outputs mpp_x slightly different
    # from mpp_y. A single factor is used so that tiles stay square.
    mpp_scale_factor = min(mpp_x, mpp_y)
    if mpp_x != mpp_y:
        print(
            f"mpp_x of {mpp_x} and mpp_y of {mpp_y} are not the same. Using smallest value: {mpp_scale_factor}"
        )

    tile_size_pix = round(tile_size_microns / mpp_scale_factor)

    # Start the grid at the tissue bounds (plus a margin of two tiles) instead
    # of the slide origin, to avoid creating tiles that can never touch tissue.
    tissue_margin_pix = tile_size_pix * 2
    minx, miny, maxx, maxy = tissue_mask_scaled.bounds
    base_x = minx - tissue_margin_pix
    base_y = miny - tissue_margin_pix

    if offsets_micron is None:
        offsets = [(base_x, base_y)]
    else:
        assert (
            len(offsets_micron) > 0
        ), "offsets_micron needs to contain at least one value"
        # Convert the requested shifts from microns to pixels.
        shifts_pix = [round(o / mpp_scale_factor) for o in offsets_micron]
        offsets = [(shift + base_x, shift + base_y) for shift in shifts_pix]

    candidate_tiles = generate_tiles(
        tile_size_pix,
        tile_size_pix,
        maxx + tissue_margin_pix,
        maxy + tissue_margin_pix,
        offsets=offsets,
    )

    # Keep only the tiles that intersect the tissue mask.
    return [tile for tile in candidate_tiles if tissue_mask_scaled.intersects(tile)]
|
|
423 |
|
|
|
424 |
def tile_is_not_empty(tile, threshold_white=20):
    """Decide whether a tile should be kept, based on its colour histogram.

    For each of the first three channels, the median pixel count over
    intensity bins 100..199 is computed. If all three medians are at or below
    ``threshold_white`` the tile is considered mostly white and is excluded.

    Args:
        tile: object with a ``histogram()`` method returning concatenated
            256-bin histograms per channel (the alpha channel is ignored).
        threshold_white: maximum median bin count for a channel to be
            regarded as empty.

    Returns:
        bool: ``False`` to exclude the tile, ``True`` to keep it.
    """
    bin_counts = tile.histogram()

    channel_medians = [
        np.median(bin_counts[256 * channel + 100 : 256 * channel + 200])
        for channel in range(3)
    ]

    return not all(median <= threshold_white for median in channel_medians)
|
|
441 |
|
|
|
442 |
def crop_rect_from_slide(slide, rect):
    """Read the region covered by ``rect`` from level 0 of the slide.

    Args:
        slide: object providing ``read_region(location, level, size)``.
        rect: geometry whose ``bounds`` are ``(minx, miny, maxx, maxy)``.

    Returns:
        Whatever ``slide.read_region`` returns for that region.
    """
    left, low, right, high = rect.bounds

    # Note that the y-axis is flipped in the slide: the top of the shapely
    # polygon is y = ymax, but in the slide it is y = 0. Hence the smallest y
    # is used for the top-left corner.
    origin = (int(left), int(low))
    extent = (int(right - left), int(high - low))
    return slide.read_region(origin, 0, extent)
|
|
448 |
|
|
|
449 |
class BagOfTiles(Dataset):
    """Dataset that yields resized tile images cropped from a slide."""

    def __init__(self, wsi, tiles, resize_to=224):
        """Store the slide and tiles and build the preprocessing pipeline.

        Args:
            wsi: slide object the tiles are cropped from.
            tiles: sequence of tile geometries with a ``bounds`` attribute.
            resize_to: size passed to ``transforms.Resize``.
        """
        self.wsi = wsi
        self.tiles = tiles

        # Input tile dimensions may not all be consistent, so every tile is
        # resized to a common size (assuming a square image) and then turned
        # into a (C x H x W) float tensor in the range [0.0, 1.0].
        self.roi_transforms = transforms.Compose(
            [transforms.Resize(resize_to), transforms.ToTensor()]
        )

    def __len__(self):
        """Return the number of tiles."""
        return len(self.tiles)

    def __getitem__(self, idx):
        """Return ``(image_tensor, tile_bounds, is_tile_kept)`` for one tile.

        The image tensor has a leading batch dimension of size 1.

        Raises:
            AssertionError: if the cropped image is not square.
        """
        tile = self.tiles[idx]

        # The pretrained model expects RGB input.
        # See https://pytorch.org/docs/stable/torchvision/models.html
        rgb_img = crop_rect_from_slide(self.wsi, tile).convert("RGB")

        # Non-square tiles are not supported, as resizing would require
        # changes to the aspect ratio.
        width, height = rgb_img.size
        assert width == height, "input image is not a square"

        # White-tile filtering (tile_is_not_empty) is currently disabled, so
        # every tile is reported as kept.
        is_tile_kept = True

        return self.roi_transforms(rgb_img).unsqueeze(0), tile.bounds, is_tile_kept
|
|
489 |
|
|
|
490 |
def collate_features(batch):
    """Collate dataset items into one image tensor and one coordinate array.

    Args:
        batch: sequence of ``(image_tensor, coords, keep_flag)`` items. Items
            with a falsy ``keep_flag`` are dropped.

    Returns:
        list: ``[images, coords]`` where ``images`` is the concatenation of
        the kept tensors along dim 0 and ``coords`` is the row-wise stack of
        the kept coordinates.
    """
    kept_items = [item for item in batch if item[2]]
    images = torch.cat([item[0] for item in kept_items], dim=0)
    coords = np.vstack([item[1] for item in kept_items])
    return [images, coords]
|
|
495 |
|
|
|
496 |
def mergedpatch_gen(features, coords, dist_threshold=4, corr_threshold = 0.6):
    """Greedily merge patches that are spatially close and semantically similar.

    Patches are visited in a fixed pseudo-random order (seed 0). Each patch
    that has not been merged yet becomes the reference of a new group, and
    every still-unused patch that lies within ``dist_threshold`` patch widths
    (L2 distance between top-left corners) and whose cosine similarity to the
    reference exceeds ``corr_threshold`` joins that group.

    Args:
        features: 2-D array with one feature vector per patch.
        coords: 2-D array with one ``(minx, miny, maxx, maxy)`` row per patch.
        dist_threshold: spatial radius in multiples of the patch side length.
        corr_threshold: minimum cosine similarity to the reference patch.

    Returns:
        list: one ``(features_subset, coords_subset)`` tuple per merged group.
        Every input patch belongs to exactly one group. An empty list is
        returned when there are no patches.
    """
    n_patches = features.shape[0]
    if n_patches == 0:
        # Nothing to merge; avoids an IndexError on coords[0, ...] below.
        return []

    # Get patch distance in pixels. Each patch is square, so the width is
    # also the height.
    patch_dist = abs(coords[0,2] - coords[0,0])
    print(patch_dist)

    # Compute feature similarity (cosine) and nearby patches (L2 norm - only
    # the top-left x,y coordinates are needed).
    cosine_matrix = cosine_similarity(features, features)
    coordinate_matrix = euclidean_distances(coords[:,:2], coords[:,:2])
    candidates = (coordinate_matrix < patch_dist*dist_threshold) & (cosine_matrix > corr_threshold)

    # NOTE: random selection for the first patch for patch merging might be
    # less biased towards tissue orientation and size. The global NumPy seed
    # is reset on purpose to keep the grouping reproducible.
    indices_avail = np.arange(n_patches)
    np.random.seed(0)
    np.random.shuffle(indices_avail)

    # Boolean mask instead of a growing list: O(1) "already merged?" checks.
    used = np.zeros(n_patches, dtype=bool)
    mergedfeatures = []
    for ref in indices_avail:
        if used[ref]:
            # This patch has been merged already.
            continue

        # Only patches that have not been merged yet may join this group.
        members = candidates[ref] & ~used
        # The reference always belongs to its own group, even when its
        # self-similarity does not pass the threshold (e.g. an all-zero
        # feature vector), so no patch can be silently lost.
        members[ref] = True

        used |= members
        mergedfeatures.append((features[members,:], coords[members,:]))

    return mergedfeatures
|
|
537 |
|
|
|
538 |
class HNSW:
    """Small wrapper around an nmslib index built with the 'hnsw' method."""

    def __init__(self, space):
        """Remember the nmslib space name used when the index is created."""
        self.space = space

    def fit(self, X):
        """Build the index from the data points in ``X`` and return ``self``.

        See https://nmslib.github.io/nmslib/quickstart.html
        """
        hnsw_index = nmslib.init(space=self.space, method='hnsw')
        hnsw_index.addDataPointBatch(X)
        hnsw_index.createIndex()
        # Only expose the index once it has been built completely.
        self.index_ = hnsw_index
        return self

    def query(self, vector, topn):
        """Return ``(indices, distances)`` of the ``topn`` nearest neighbours.

        ``fit`` must have been called first.
        """
        neighbours, distances = self.index_.knnQuery(vector, k=topn)
        return neighbours, distances
|
|
553 |
|
|
|
554 |
@torch.no_grad()
def extract_features(model, device, wsi, filtered_tiles, workers, out_size, batch_size, n_last_blocks, avgpool_patchtokens, depths):
    """Run the encoder over all tiles and collect features and coordinates.

    Args:
        model: encoder exposing ``forward_return_n_last_blocks`` (EsVIT in
            this example; swap in your own feature extractor if needed).
        device: torch device the batches are moved to.
        wsi: slide the tiles are cropped from.
        filtered_tiles: tiles to process.
        workers: data-loader workers, only used when ``device`` is CUDA.
        out_size: size the tiles are resized to.
        batch_size: number of tiles per batch.
        n_last_blocks, avgpool_patchtokens, depths: forwarded unchanged to
            ``model.forward_return_n_last_blocks``.

    Returns:
        tuple: ``(features, coords)`` as NumPy arrays with one row per tile.
    """
    # Use multiple workers if running on the GPU, otherwise all workers are
    # needed for evaluating the model.
    loader_options = {}
    if device.type == "cuda":
        loader_options = {"num_workers": workers, "pin_memory": True}

    tile_loader = DataLoader(
        dataset=BagOfTiles(wsi, filtered_tiles, resize_to=out_size),
        batch_size=batch_size,
        collate_fn=collate_features,
        **loader_options,
    )

    collected_features = []
    collected_coords = []
    for tile_batch, tile_coords in tile_loader:
        tile_batch = tile_batch.to(device, non_blocking=True)
        embeddings = model.forward_return_n_last_blocks(
            tile_batch, n_last_blocks, avgpool_patchtokens, depths
        )
        collected_features.extend(embeddings.cpu().numpy())
        collected_coords.extend(tile_coords)

    return np.asarray(collected_features), np.asarray(collected_coords)
|
|
575 |
|
|
|
576 |
def extract_save_features(args):
    """Segment a slide, tile it, extract features, merge patches and save them.

    Writes two files into ``args.output_dir``:
    ``<slide_id>_features_QC.png`` (tile overlay for quality control) and
    ``<slide_id>_features.h5`` (the merged patches, written with
    ``torch.save`` despite the ``.h5`` extension).

    Args:
        args: parsed command-line arguments (see ``get_args_parser``).

    Raises:
        Exception: if the ``_features`` output file already exists.
    """
    # Derive the slide ID from its name.
    slide_id, _ = os.path.splitext(os.path.basename(args.input_slide))
    wip_file_path = os.path.join(args.output_dir, slide_id + "_wip.h5")
    output_file_path = os.path.join(args.output_dir, slide_id + "_features.h5")

    os.makedirs(args.output_dir, exist_ok=True)

    # Check if the _features output file already exist. If so, we terminate to avoid
    # overwriting it by accident. This also simplifies resuming bulk batch jobs.
    if os.path.exists(output_file_path):
        raise Exception(f"{output_file_path} already exists")

    # Open the slide for reading; it is closed again once all tiles were read.
    wsi = openslide.open_slide(args.input_slide)
    try:
        # Decide on which slide level we want to base the segmentation.
        seg_level = wsi.get_best_level_for_downsample(64)

        # Run the segmentation and tiling procedure.
        start_time = time.time()
        tissue_mask_scaled = create_tissue_mask(wsi, seg_level, method=args.method)
        filtered_tiles = create_tissue_tiles(wsi, tissue_mask_scaled, args.tile_size)

        # Build a figure for quality control purposes, to check if the tiles are where we expect them.
        qc_img = make_tile_QC_fig(filtered_tiles, wsi, seg_level, 2)
        qc_img_target_width = 1920
        qc_img = qc_img.resize((qc_img_target_width, int(qc_img.height / (qc_img.width / qc_img_target_width))))
        qc_img_file_path = os.path.join(args.output_dir, f"{slide_id}_features_QC.png")
        qc_img.save(qc_img_file_path)
        print(f"Finished creating {len(filtered_tiles)} tissue tiles in {time.time() - start_time}s")

        # NOTE: an earlier version built a second QC file name from
        # len(mergedpatches) at this point, before mergedpatches existed,
        # which raised UnboundLocalError on every run. The path was never
        # used to save anything, so the statement was removed.

        # Extract the rectangles, and compute the feature vectors. Example using EsVIT.
        device = torch.device("cuda")
        model, _, depths = load_encoder_esVIT(args, device)

        features, coords = extract_features(
            model,
            device,
            wsi,
            filtered_tiles,
            args.workers,
            args.out_size,
            args.batch_size,
            n_last_blocks=args.n_last_blocks,
            avgpool_patchtokens=args.avgpool_patchtokens,
            depths=depths,
        )
    finally:
        wsi.close()

    print(f'Number of features N={len(features)}')

    # Merging nearby patches with similar semantic.
    mergedpatches = mergedpatch_gen(
        features,
        coords,
        dist_threshold=args.dist_threshold,
        corr_threshold=args.corr_threshold,
    )
    print(f'Merging step => N={len(mergedpatches)}')

    # Saving features.
    torch.save(mergedpatches, wip_file_path)

    # Rename the file containing the patches to ensure we can easily
    # distinguish incomplete bags of patches (due to e.g. errors) from complete ones in case a job fails.
    os.rename(wip_file_path, output_file_path)

    print('Done.')
|
|
643 |
|
|
|
644 |
if __name__ == '__main__':
    # Script entry point: parse the arguments, validate them, run the pipeline.
    parser = argparse.ArgumentParser('Preprocessing script esvit', parents=[get_args_parser()])
    args = parser.parse_args()

    # Explicit checks instead of `assert`: assertions are stripped when Python
    # runs with -O, and os.path.isfile(None) raises a TypeError when an
    # option was omitted.
    if not args.checkpoint or not os.path.isfile(args.checkpoint):
        parser.error(f'{args.checkpoint} does not exist')
    if not args.input_slide or not os.path.isfile(args.input_slide):
        parser.error(f'{args.input_slide} does not exist')
    if not args.output_dir:
        parser.error('--output_dir is required')
    if not torch.cuda.is_available():
        raise SystemExit('Need cuda for this job')

    extract_save_features(args)