Diff of /extract_features.py [000000] .. [352cae]

Switch to unified view

a b/extract_features.py
1
import argparse
2
import os
3
import time
4
import numpy as np
5
6
import openslide
7
import cv2
8
from PIL import Image, ImageDraw
9
from shapely.affinity import scale
10
from shapely.geometry import Polygon, MultiPolygon
11
from shapely.ops import unary_union
12
from collections import defaultdict
13
14
import nmslib
15
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
16
17
# Only needed for the stain-deconvolution segmentation method (segment_tissue_deconv_stain).
18
import histomicstk as htk #pip install histomicstk --find-links https://girder.github.io/large_image_wheels
19
20
import torch
21
from torch.utils.data import DataLoader, Dataset
22
from torchvision import transforms
23
from esvit.utils import bool_flag
24
25
# You can use your own encoder to extract features. The loaders below are examples, including EsVIT, the encoder used in the publication.
26
from encoders import load_encoder_esVIT, load_encoder_resnet
27
28
def get_args_parser():
    """Build the command-line parser for the feature-extraction script.

    The parser is created with ``add_help=False`` so it can be handed to
    another ``argparse.ArgumentParser`` through ``parents=[...]``.

    Returns:
        argparse.ArgumentParser: parser exposing the slide/tiling options, the
        patch-merging thresholds and the EsVIT encoder options.
    """
    parser = argparse.ArgumentParser('Preprocessing script esvit', add_help=False)

    # --- Input / output locations -------------------------------------------
    parser.add_argument(
        "--input_slide",
        type=str,
        help="Path to input WSI file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save output data",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Feature extractor weights checkpoint",
    )

    # --- Tiling and data loading --------------------------------------------
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--tile_size",
        help="Desired tile size in microns (should be the same value as used in feature extraction model).",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--out_size",
        help="Resize the square tile to this output size (in pixels).",
        type=int,
        default=224,
    )
    parser.add_argument(
        "--method",
        # The segmentation code compares against the literal 'stain_deconv';
        # every other value selects Otsu thresholding, so spell it out exactly.
        help="Segmentation method: 'otsu' or 'stain_deconv' (any other value falls back to 'otsu')",
        type=str,
        default='otsu',
    )

    # --- Patch merging --------------------------------------------------------
    parser.add_argument(
        "--dist_threshold",
        type=int,
        default=4,
        help="L2 norm distance when spatially merging pacthes.",
    )
    parser.add_argument(
        "--corr_threshold",
        type=float,
        default=0.6,
        help="Cosine similarity distance when semantically merging pacthes.",
    )
    parser.add_argument(
        "--workers",
        help="The number of workers to use for the data loader. Only relevant when using a GPU.",
        type=int,
        default=4,
    )

    # --- EsVIT encoder options ------------------------------------------------
    parser.add_argument(
        '--cfg',
        help='experiment configure file name. See EsVIT repo.',
        type=str
    )
    parser.add_argument(
        '--arch', default='deit_small', type=str,
        choices=['cvt_tiny', 'swin_tiny','swin_small', 'swin_base', 'swin_large', 'swin', 'vil', 'vil_1281', 'vil_2262', 'deit_tiny', 'deit_small', 'vit_base'],
        help="""Name of architecture to train. For quick experiments with ViTs, we recommend using deit_tiny or deit_small. See EsVIT repo."""
    )
    parser.add_argument(
        '--n_last_blocks',
        default=4,
        type=int,
        help="""Concatenate [CLS] tokens for the `n` last blocks. We use `n=4` when evaluating DeiT-Small and `n=1` with ViT-Base. See EsVIT repo."""
    )
    parser.add_argument(
        '--avgpool_patchtokens',
        default=False,
        type=bool_flag,
        help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for DeiT-Small and to True with ViT-Base. See EsVIT repo."""
    )
    parser.add_argument(
        '--patch_size',
        default=8,
        type=int,
        help='Patch resolution of the model. See EsVIT repo.'
    )
    parser.add_argument(
        'opts',
        help="Modify config options using the command-line. See EsVIT repo.",
        default=None,
        nargs=argparse.REMAINDER
    )
    parser.add_argument(
        "--rank",
        default=0,
        type=int,
        help="Please ignore and do not set this argument.")

    return parser
128
129
def segment_tissue(img):
    """Find tissue contours by Otsu-thresholding the saturation channel.

    Args:
        img: image array accepted by ``cv2.cvtColor(..., COLOR_RGB2HSV)``
            (presumably the RGB(A) array read from the slide -- confirm
            against the caller).

    Returns:
        The ``(contours, hierarchy)`` pair produced by ``cv2.findContours``
        with ``RETR_CCOMP`` retrieval (two-level outer/hole hierarchy).
    """
    median_kernel_size = 7
    closing_size = 4

    # Saturation separates stained tissue from the near-white background.
    saturation = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel_size)
    binary = cv2.threshold(
        smoothed, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY
    )[1]

    # Morphological closing fills small gaps in the thresholded mask.
    structuring_element = np.ones((closing_size, closing_size), np.uint8)
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, structuring_element)

    contours, hierarchy = cv2.findContours(
        closed, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )
    return contours, hierarchy
144
145
def segment_tissue_deconv_stain(img):
    """
    Method 2: tissue segmentation using stain deconvolution (alternative to Otsu
    thresholding on saturation).

    The image is separated into stains with HistomicsTK's Macenko PCA estimate;
    the hematoxylin and eosin channels are each Otsu-thresholded and OR-ed
    together, then the mask is cleaned with repeated open/close operations.

    Args:
        img: array whose last channel is treated as alpha (assumes a 4-channel
            RGBA array -- the transparent-pixel fill below writes 4 values).

    Returns:
        2-D ``uint8`` mask with the same height/width as ``img``.
    """
    color_deconv = htk.preprocessing.color_deconvolution
    background_intensity = 255

    rgba = img.copy()
    # Fully transparent pixels are painted white so they read as background.
    rgba[rgba[..., -1] == 0] = [255, 255, 255, 0]
    rgb = np.asarray(Image.fromarray(rgba).convert('RGB'))

    stain_matrix = color_deconv.rgb_separate_stains_macenko_pca(rgb, background_intensity)
    deconvolved = color_deconv.color_deconvolution(rgb, stain_matrix, background_intensity)

    mask = np.zeros(rgb.shape[0:2], np.uint8)
    for stain_name in ('hematoxylin', 'eosin'):  # nuclei stain, cytoplasm stain
        channel = color_deconv.find_stain_index(
            color_deconv.stain_color_map[stain_name], stain_matrix)
        _, stain_mask = cv2.threshold(
            255 - deconvolved.Stains[:, :, channel],
            0, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
        mask = cv2.bitwise_or(mask, stain_mask)

    # Five rounds of opening then closing with a 3x3 kernel to remove specks
    # and fill pinholes.
    kernel = np.ones((3, 3), np.uint8)
    for _ in range(5):
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask
187
188
def mask_to_polygons(mask, min_area, min_area_holes=10., epsilon=10.):
    """Convert a mask ndarray (binarized image) to a MultiPolygon.

    Args:
        mask: binary image passed to ``cv2.findContours``.
        min_area: outer contours with a smaller ``cv2.contourArea`` are dropped.
        min_area_holes: hole contours with a smaller area are ignored.
        epsilon: unused; kept so existing callers passing it keep working.

    Returns:
        shapely ``MultiPolygon`` of the retained regions (empty when the mask
        contains no contours at all).

    Raises:
        Exception: when contours exist but none survives the filtering.
    """
    # first, find contours with cv2: it's much faster than shapely
    contours, hierarchy = cv2.findContours(mask,
                                           cv2.RETR_CCOMP,
                                           cv2.CHAIN_APPROX_NONE)
    if not contours:
        return MultiPolygon()

    assert hierarchy.shape[0] == 1

    # Group hole contours under their parent (outer) contour.
    # http://docs.opencv.org/3.1.0/d9/d8b/tutorial_py_contours_hierarchy.html
    holes_by_parent = defaultdict(list)
    child_indices = set()
    for idx, (_, _, _, parent_idx) in enumerate(hierarchy[0]):
        if parent_idx != -1:
            child_indices.add(idx)
            holes_by_parent[parent_idx].append(contours[idx])

    # create actual polygons filtering by area (removes artifacts)
    all_polygons = []
    for idx, cnt in enumerate(contours):
        if idx in child_indices or cv2.contourArea(cnt) < min_area:
            continue
        # A ring needs at least 3 distinct points; only reachable when
        # min_area <= 0, but it would otherwise crash the Polygon constructor.
        if len(cnt) < 3:
            continue

        assert cnt.shape[1] == 1
        poly = Polygon(
            shell=cnt[:, 0, :],
            holes=[c[:, 0, :] for c in holes_by_parent.get(idx, [])
                   if cv2.contourArea(c) >= min_area_holes])

        if not poly.is_valid:
            # This is likely because the polygon is self-touching or self-crossing.
            # Try and 'correct' the polygon using the zero-length buffer() trick.
            # See https://shapely.readthedocs.io/en/stable/manual.html#object.buffer
            poly = poly.buffer(0)

        # buffer(0) may split the shape into a MultiPolygon or collapse it to an
        # empty geometry; MultiPolygon([...]) only accepts plain polygons, so
        # flatten multi-part results and drop empties before collecting.
        parts = list(poly.geoms) if hasattr(poly, "geoms") else [poly]
        all_polygons.extend(
            part for part in parts
            if isinstance(part, Polygon) and not part.is_empty
        )

    if not all_polygons:
        raise Exception("Raw tissue mask consists of 0 polygons")

    return MultiPolygon(all_polygons)
234
235
def detect_foreground(contours, hierarchy):
    """Split ``cv2.findContours`` output into outer contours and their holes.

    Args:
        contours: sequence of contours as returned by ``cv2.findContours``.
        hierarchy: matching hierarchy array of shape ``(1, N, 4)`` whose last
            column is the parent index (``-1`` for top-level contours), or
            ``None`` when no contour was found.

    Returns:
        ``(foreground_contours, hole_contours)`` where ``hole_contours[i]`` is
        the list of contours whose parent is ``foreground_contours[i]``.
        Both lists are empty when there are no contours.
    """
    # cv2.findContours returns hierarchy=None for an empty mask; squeezing that
    # would raise an obscure axis error, so report "nothing found" instead.
    if hierarchy is None or len(contours) == 0:
        return [], []

    # Column 3 of each hierarchy row holds the parent contour index.
    parents = np.squeeze(hierarchy, axis=(0,))[:, 3]

    # find foreground contours (parent == -1)
    foreground_ids = np.flatnonzero(parents == -1)
    foreground_contours = [contours[idx] for idx in foreground_ids]

    # For each foreground contour, gather every contour that names it as parent.
    hole_contours = [
        [contours[idx] for idx in np.flatnonzero(parents == fg_idx)]
        for fg_idx in foreground_ids
    ]

    return foreground_contours, hole_contours
252
253
def construct_polygon(foreground_contours, hole_contours, min_area):
    """Build the tissue geometry from outer contours and their holes.

    Args:
        foreground_contours: outer contours (point arrays).
        hole_contours: for each outer contour, the list of its hole contours.
        min_area: outer polygons and holes with a smaller area are ignored.

    Returns:
        The ``unary_union`` of all retained polygons.

    Raises:
        Exception: when no polygon is retained.
    """
    tissue_parts = []
    for outline, inner_contours in zip(foreground_contours, hole_contours):
        # Fewer than 3 points cannot form a polygon.
        if len(outline) < 3:
            continue

        # Drop the redundant contour dimension before handing it to shapely.
        region = Polygon(np.squeeze(outline))

        # Too small to be of interest.
        if region.area < min_area:
            continue

        if not region.is_valid:
            # Likely self-touching or self-crossing; repair with the
            # zero-length buffer() trick.
            # See https://shapely.readthedocs.io/en/stable/manual.html#object.buffer
            region = region.buffer(0)

        # Punch out every usable hole: at least 3 points, valid, large enough.
        for inner in inner_contours:
            if len(inner) >= 3:
                cavity = Polygon(np.squeeze(inner))
                if cavity.is_valid and cavity.area >= min_area:
                    region = region.difference(cavity)

        tissue_parts.append(region)

    if not tissue_parts:
        raise Exception("Raw tissue mask consists of 0 polygons")

    # Merging overlaps yields a Polygon or MultiPolygon for most tissue masks.
    return unary_union(tissue_parts)
297
298
def generate_tiles(tile_width_pix, tile_height_pix, img_width, img_height, offsets=((0, 0),)):
    """Generate rectangular tiles covering the whole image.

    Args:
        tile_width_pix: tile width in pixels (used as the ``range`` step, so it
            must be an integer).
        tile_height_pix: tile height in pixels (integer, same reason).
        img_width: width of the area to cover.
        img_height: height of the area to cover.
        offsets: iterable of ``(x, y)`` grid origins; several origins give a
            stride-like overlap effect. The default is an immutable tuple so the
            default object can never be mutated between calls.

    Returns:
        list of shapely ``Polygon`` rectangles. For a grid point ``(x, y)`` the
        rectangle spans ``x .. x + tile_width_pix`` horizontally and
        ``y - tile_height_pix .. y`` vertically.
    """
    # One extra tile size is added to the range stop so tiles are not cut off
    # at the far edges.
    stop_x = int(np.ceil(img_width + tile_width_pix))
    stop_y = int(np.ceil(img_height + tile_height_pix))

    rects = []
    for xmin, ymin in offsets:
        for x in range(int(np.floor(xmin)), stop_x, tile_width_pix):
            for y in range(int(np.floor(ymin)), stop_y, tile_height_pix):
                rects.append(
                    Polygon(
                        [
                            (x, y),
                            (x + tile_width_pix, y),
                            (x + tile_width_pix, y - tile_height_pix),
                            (x, y - tile_height_pix),
                        ]
                    )
                )
    return rects
321
322
def make_tile_QC_fig(tiles, slide, level, line_width_pix=1, extra_tiles=None):
    """Render tile outlines on the slide image of the given pyramid level.

    Args:
        tiles: tiles (objects with ``.bounds``) drawn in light green.
        slide: slide object providing ``read_region``, ``level_dimensions``
            and ``level_downsamples``.
        level: pyramid level used for the background image.
        line_width_pix: outline width for ``tiles``.
        extra_tiles: optional additional tiles (e.g. excluded or sampled),
            drawn in blue with a one-pixel-wider outline.

    Returns:
        The image returned by ``slide.read_region`` with the outlines drawn on it.
    """
    canvas = slide.read_region((0, 0), level, slide.level_dimensions[level])
    # Tile bounds are multiplied by this factor to land on the rendered level.
    to_level = 1 / slide.level_downsamples[level]
    pen = ImageDraw.Draw(canvas, "RGBA")

    def outline_all(rects, colour, width):
        for rect in rects:
            pen.rectangle(
                tuple(np.array(rect.bounds) * to_level),
                outline=colour,
                width=width,
            )

    outline_all(tiles, "lightgreen", line_width_pix)
    if extra_tiles:
        outline_all(extra_tiles, "blue", line_width_pix + 1)

    return canvas
339
340
def create_tissue_mask(wsi, seg_level, method='otsu'):
    """Segment tissue on one pyramid level and return it in level-0 coordinates.

    Args:
        wsi: slide object providing ``level_dimensions``, ``level_downsamples``
            and ``read_region``.
        seg_level: pyramid level on which the segmentation is computed.
        method: ``'stain_deconv'`` selects stain-deconvolution segmentation;
            any other value uses Otsu thresholding.

    Returns:
        The tissue geometry scaled by the level's downsample factor around the
        origin ``(0, 0)``.
    """
    level_size = wsi.level_dimensions[seg_level]
    level_img = np.array(wsi.read_region((0, 0), seg_level, level_size))

    # Minimum surface area of tissue polygons (in pixels of this level): a
    # 500th of the level's total area. This should be sensible in the context
    # of the chosen tile size.
    min_area = (level_size[0] * level_size[1]) / 500

    if method != 'stain_deconv':
        contours, hierarchy = segment_tissue(level_img)
        outer_contours, holes = detect_foreground(contours, hierarchy)
        mask_geometry = construct_polygon(outer_contours, holes, min_area)
    else:
        mask_geometry = mask_to_polygons(
            segment_tissue_deconv_stain(level_img), min_area
        )

    # Bring the geometry into the coordinate space of the slide's level 0.
    to_level0 = wsi.level_downsamples[seg_level]
    return scale(
        mask_geometry, xfact=to_level0, yfact=to_level0, zfact=1.0, origin=(0, 0)
    )
368
369
def create_tissue_tiles(wsi, tissue_mask_scaled, tile_size_microns, offsets_micron=None):
    """Create the tiles that intersect the tissue mask.

    Args:
        wsi: slide whose ``properties`` contain the microns-per-pixel entries.
        tissue_mask_scaled: tissue geometry providing ``bounds`` and
            ``intersects``.
        tile_size_microns: desired tile edge length in microns.
        offsets_micron: optional non-empty sequence of grid offsets in microns,
            each applied to both axes.

    Returns:
        list of tile polygons intersecting ``tissue_mask_scaled``.

    Raises:
        AssertionError: when an MPP property is missing or ``offsets_micron``
            is empty.
    """
    print(f"tile size is {tile_size_microns} um")

    # The tile size in pixels is derived from the size in microns and the
    # slide resolution, so both MPP properties are mandatory.
    props = wsi.properties
    assert (
        openslide.PROPERTY_NAME_MPP_X in props
    ), "microns per pixel along X-dimension not available"
    assert (
        openslide.PROPERTY_NAME_MPP_Y in props
    ), "microns per pixel along Y-dimension not available"

    mpp_x = float(props[openslide.PROPERTY_NAME_MPP_X])
    mpp_y = float(props[openslide.PROPERTY_NAME_MPP_Y])

    # For larger tiles in micron, the NKI scanner reports an mpp_x slightly
    # different from mpp_y; use the smaller one so tiles stay square.
    mpp_scale_factor = min(mpp_x, mpp_y)
    if mpp_x != mpp_y:
        print(
            f"mpp_x of {mpp_x} and mpp_y of {mpp_y} are not the same. Using smallest value: {mpp_scale_factor}"
        )

    tile_size_pix = round(tile_size_microns / mpp_scale_factor)

    # Start the grid at the tissue bounds (minus a margin of two tiles) to
    # avoid creating tiles that can never touch the tissue mask.
    margin_pix = tile_size_pix * 2
    minx, miny, maxx, maxy = tissue_mask_scaled.bounds
    base_x = minx - margin_pix
    base_y = miny - margin_pix

    if offsets_micron is None:
        offsets = [(base_x, base_y)]
    else:
        assert (
            len(offsets_micron) > 0
        ), "offsets_micron needs to contain at least one value"
        # Convert each micron offset to pixels and shift the grid origin by it.
        shifts_pix = [round(o / mpp_scale_factor) for o in offsets_micron]
        offsets = [(shift + base_x, shift + base_y) for shift in shifts_pix]

    candidate_tiles = generate_tiles(
        tile_size_pix,
        tile_size_pix,
        maxx + margin_pix,
        maxy + margin_pix,
        offsets=offsets,
    )

    # Keep only the tiles that touch the tissue mask.
    return [tile for tile in candidate_tiles if tissue_mask_scaled.intersects(tile)]
423
424
def tile_is_not_empty(tile, threshold_white=20):
    """Decide whether a tile holds enough mid-range colour to be kept.

    For each of the first three channels (alpha is ignored), the median of the
    histogram *counts* over intensity bins 100..199 is computed. The tile is
    treated as mostly white only when all three medians are at or below
    ``threshold_white``.

    Args:
        tile: object whose ``histogram()`` returns the concatenated
            per-channel 256-bin histograms (e.g. a PIL image).
        threshold_white: maximum per-channel median count for an "empty" tile.

    Returns:
        ``False`` to exclude the tile, ``True`` to keep it.
    """
    counts = tile.histogram()

    mid_range_medians = [
        np.median(counts[256 * channel_id : 256 * (channel_id + 1)][100:200])
        for channel_id in range(3)
    ]

    # Keep the tile unless every channel looks empty in the mid range.
    return not all(median <= threshold_white for median in mid_range_medians)
441
442
def crop_rect_from_slide(slide, rect):
    """Read the level-0 region covered by ``rect`` from ``slide``.

    Args:
        slide: object providing ``read_region(location, level, size)``.
        rect: object whose ``bounds`` is ``(minx, miny, maxx, maxy)``.

    Returns:
        Whatever ``slide.read_region`` returns for that region.
    """
    x0, y0, x1, y1 = rect.bounds
    # The y-axis is flipped in the slide: the top of the shapely polygon is
    # y = ymax, but in the slide it is y = 0. Hence miny (not maxy) is used
    # for the top-left corner.
    origin = (int(x0), int(y0))
    extent = (int(x1 - x0), int(y1 - y0))
    return slide.read_region(origin, 0, extent)
448
449
class BagOfTiles(Dataset):
    """Dataset yielding one resized tile tensor per tissue tile of a slide."""

    def __init__(self, wsi, tiles, resize_to=224):
        """Store the slide and tiles and build the per-tile transform.

        Args:
            wsi: slide object the tiles are cropped from.
            tiles: sequence of tile geometries (objects with ``.bounds``).
            resize_to: size passed to ``transforms.Resize``.
        """
        self.wsi = wsi
        self.tiles = tiles

        # Tile dimensions may not all be consistent, so each tile is resized to
        # a common size (assumes a square image) and then converted from a PIL
        # image to a (C x H x W) float tensor in the range [0.0, 1.0].
        self.roi_transforms = transforms.Compose(
            [transforms.Resize(resize_to), transforms.ToTensor()]
        )

    def __len__(self):
        """Return the number of tiles."""
        return len(self.tiles)

    def __getitem__(self, idx):
        """Return ``(tensor, bounds, keep_flag)`` for tile ``idx``.

        The tensor carries a leading batch dimension of 1; ``bounds`` is the
        tile's ``.bounds``; ``keep_flag`` is always ``True`` at the moment.
        """
        tile = self.tiles[idx]

        # The pretrained model expects RGB input.
        # See https://pytorch.org/docs/stable/torchvision/models.html
        rgb_tile = crop_rect_from_slide(self.wsi, tile).convert("RGB")

        # White-tile filtering (tile_is_not_empty with threshold_white=20) is
        # currently disabled; it would run here because it needs the crop.
        is_tile_kept = True

        # Non-square tiles are not supported: resizing them would change the
        # aspect ratio.
        width, height = rgb_tile.size
        assert width == height, "input image is not a square"

        tensor = self.roi_transforms(rgb_tile).unsqueeze(0)
        return tensor, tile.bounds, is_tile_kept
489
490
def collate_features(batch):
    """Collate dataset items into one image tensor and one coordinate array.

    Args:
        batch: iterable of ``(image_tensor, coords, keep_flag)`` items; items
            whose flag is falsy are dropped.

    Returns:
        ``[images, coords]`` where ``images`` is the kept tensors concatenated
        along dim 0 and ``coords`` the kept coordinates stacked row-wise.
    """
    kept = [(item[0], item[1]) for item in batch if item[2]]
    images = torch.cat([tensor for tensor, _ in kept], dim=0)
    coords = np.vstack([bounds for _, bounds in kept])
    return [images, coords]
495
496
def mergedpatch_gen(features, coords, dist_threshold=4, corr_threshold = 0.6):
    """Greedily group patches that are spatially close and semantically similar.

    Patches are visited in a fixed pseudo-random order (seed 0). Each patch not
    yet assigned becomes a reference; every still-unassigned patch whose
    top-left corner lies within ``patch_dist * dist_threshold`` of the
    reference and whose cosine similarity to it exceeds ``corr_threshold``
    joins the reference's group.

    Args:
        features: ``(N, D)`` array of patch feature vectors.
        coords: ``(N, 4)`` array of patch bounds; columns 0 and 2 are used for
            the patch width and columns 0..1 as the position.
        dist_threshold: distance threshold expressed in patch widths.
        corr_threshold: minimum cosine similarity for merging.

    Returns:
        list of ``(group_features, group_coords)`` tuples, one per merged group.
        Empty when there are no patches.
    """
    n_patches = features.shape[0]
    if n_patches == 0:
        # Nothing to merge; indexing coords[0] below would fail.
        return []

    # Patches are square, so the width doubles as the unit of distance.
    patch_dist = abs(coords[0,2] - coords[0,0])
    print(patch_dist)

    # Feature similarity (cosine) and spatial distance (L2 norm; only the
    # top-left x, y coordinates are needed).
    cosine_matrix = cosine_similarity(features, features)
    coordinate_matrix = euclidean_distances(coords[:,:2], coords[:,:2])

    # NOTE: a random visiting order might be less biased towards tissue
    # orientation and size. The global NumPy seed is reset here, as before, so
    # results stay reproducible.
    indices_avail = np.arange(n_patches)
    np.random.seed(0)
    np.random.shuffle(indices_avail)

    max_dist = patch_dist * dist_threshold
    # Boolean mask instead of a growing list: O(1) "already merged?" checks.
    used = np.zeros(n_patches, dtype=bool)

    mergedfeatures = []
    for ref in indices_avail:
        if used[ref]:
            # This patch has been merged already.
            continue

        members = (
            (coordinate_matrix[ref] < max_dist)
            & (cosine_matrix[ref] > corr_threshold)
            & ~used
        )
        # The reference always belongs to its own group, even when its
        # self-similarity does not exceed the threshold (e.g. an all-zero
        # feature vector or corr_threshold >= 1).
        members[ref] = True

        used |= members
        mergedfeatures.append((features[members, :], coords[members, :]))

    n_used = int(used.sum())
    assert n_used == n_patches, f'Probably issue in contruscting merged features for graph {n_used}!={n_patches}'

    return mergedfeatures
537
538
class HNSW:
    """Small wrapper around an nmslib HNSW nearest-neighbour index."""

    def __init__(self, space):
        """Remember the nmslib ``space`` name used when the index is built."""
        self.space = space

    def fit(self, X):
        """Build the index over ``X`` and return ``self``.

        See https://nmslib.github.io/nmslib/quickstart.html
        """
        hnsw_index = nmslib.init(space=self.space, method='hnsw')
        hnsw_index.addDataPointBatch(X)
        hnsw_index.createIndex()
        # Only expose the index once it has been fully built.
        self.index_ = hnsw_index
        return self

    def query(self, vector, topn):
        """Return ``(indices, distances)`` of the ``topn`` neighbours of ``vector``."""
        neighbour_ids, distances = self.index_.knnQuery(vector, k=topn)
        return neighbour_ids, distances
553
554
@torch.no_grad()
def extract_features(model, device, wsi, filtered_tiles, workers, out_size, batch_size, n_last_blocks, avgpool_patchtokens, depths):
    """Run the encoder over every tile and collect features and coordinates.

    Args:
        model: encoder exposing ``forward_return_n_last_blocks``.
        device: torch device the batches are moved to.
        wsi: slide the tiles are read from.
        filtered_tiles: tiles to encode.
        workers: data-loader worker count (only applied on CUDA devices).
        out_size: tile size after resizing.
        batch_size: data-loader batch size.
        n_last_blocks, avgpool_patchtokens, depths: forwarded unchanged to
            ``model.forward_return_n_last_blocks``.

    Returns:
        ``(features, coords)`` as NumPy arrays with one row per tile.
    """
    # Use multiple workers only on the GPU; on CPU they are all needed for
    # evaluating the model.
    loader_options = {}
    if device.type == "cuda":
        loader_options = {"num_workers": workers, "pin_memory": True}

    tile_loader = DataLoader(
        dataset=BagOfTiles(wsi, filtered_tiles, resize_to=out_size),
        batch_size=batch_size,
        collate_fn=collate_features,
        **loader_options,
    )

    features_, coords_ = [], []
    for batch, coords in tile_loader:
        on_device = batch.to(device, non_blocking=True)
        # NOTE: example using EsVIT; call your own feature extractor here otherwise.
        embedded = model.forward_return_n_last_blocks(
            on_device, n_last_blocks, avgpool_patchtokens, depths
        )
        features_.extend(embedded.cpu().numpy())
        coords_.extend(coords)

    return np.asarray(features_), np.asarray(coords_)
575
576
def extract_save_features(args):
    """Segment, tile and encode one slide, then save its merged patch features.

    Steps: build the tissue mask, create the tissue tiles, write a QC image of
    the tiles, extract features with the EsVIT encoder, merge nearby similar
    patches, and save the result with ``torch.save``.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``).

    Raises:
        Exception: when the ``_features`` output file already exists.
    """
    # Derive the slide ID from its name.
    slide_id, _ = os.path.splitext(os.path.basename(args.input_slide))
    wip_file_path = os.path.join(args.output_dir, slide_id + "_wip.h5")
    output_file_path = os.path.join(args.output_dir, slide_id + "_features.h5")

    os.makedirs(args.output_dir, exist_ok=True)

    # Refuse to overwrite an existing _features file. This also simplifies
    # resuming bulk batch jobs.
    if os.path.exists(output_file_path):
        raise Exception(f"{output_file_path} already exists")

    # Open the slide for reading; it is closed again even if a step fails.
    wsi = openslide.open_slide(args.input_slide)
    try:
        # Decide on which slide level we want to base the segmentation.
        seg_level = wsi.get_best_level_for_downsample(64)

        # Run the segmentation and tiling procedure.
        start_time = time.time()
        tissue_mask_scaled = create_tissue_mask(wsi, seg_level, method=args.method)
        filtered_tiles = create_tissue_tiles(wsi, tissue_mask_scaled, args.tile_size)

        # Build and save a figure for quality control purposes, to check that
        # the tiles are where we expect them.
        qc_img = make_tile_QC_fig(filtered_tiles, wsi, seg_level, 2)
        qc_img_target_width = 1920
        qc_img = qc_img.resize((qc_img_target_width, int(qc_img.height / (qc_img.width / qc_img_target_width))))
        qc_img_file_path = os.path.join(args.output_dir, f"{slide_id}_features_QC.png")
        qc_img.save(qc_img_file_path)
        print(f"Finished creating {len(filtered_tiles)} tissue tiles in {time.time() - start_time}s")

        # NOTE: a second, unused QC path was previously built here from
        # len(mergedpatches) before that variable existed, which raised
        # UnboundLocalError on every run. It has been removed.

        # Extract the rectangles, and compute the feature vectors. Example using EsVIT.
        device = torch.device("cuda")
        model, _, depths = load_encoder_esVIT(args, device)

        features, coords = extract_features(
            model,
            device,
            wsi,
            filtered_tiles,
            args.workers,
            args.out_size,
            args.batch_size,
            n_last_blocks=args.n_last_blocks,
            avgpool_patchtokens=args.avgpool_patchtokens,
            depths=depths,
        )
    finally:
        wsi.close()

    print(f'Number of features N={len(features)}')
    # Merging nearby patches with similar semantic.
    mergedpatches = mergedpatch_gen(features, coords, dist_threshold=args.dist_threshold, corr_threshold=args.corr_threshold)
    print(f'Merging step => N={len(mergedpatches)}')

    # Saving features. NOTE: the file is written by torch.save even though the
    # name ends in ".h5".
    torch.save(mergedpatches, wip_file_path)

    # Rename the file containing the patches so incomplete bags of patches
    # (due to e.g. errors) can be told apart from complete ones if a job fails.
    os.rename(wip_file_path, output_file_path)

    print('Done.')
643
644
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Preprocessing script esvit', parents=[get_args_parser()])
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`: assertions are stripped
    # under `python -O`, and os.path.isfile(None) raises TypeError when a path
    # option is omitted. parser.error prints the usage and exits with status 2.
    if not args.checkpoint or not os.path.isfile(args.checkpoint):
        parser.error(f'{args.checkpoint} does not exist')
    if not torch.cuda.is_available():
        parser.error('Need cuda for this job')
    if not args.input_slide or not os.path.isfile(args.input_slide):
        parser.error(f'{args.input_slide} does not exist')
    if not args.output_dir:
        parser.error('--output_dir is required')

    extract_save_features(args)