# vis_utils/heatmap_utils.py
import numpy as np
2
import torch
3
import torch.nn as nn
4
import torch.nn.functional as F
5
import pdb
6
import os
7
import pandas as pd
8
from utils.utils import *
9
from PIL import Image
10
from math import floor
11
import matplotlib.pyplot as plt
12
from datasets.wsi_dataset import Wsi_Region
13
import h5py
14
from wsi_core.WholeSlideImage import WholeSlideImage
15
from scipy.stats import percentileofscore
16
import math
17
from utils.file_utils import save_hdf5
18
from scipy.stats import percentileofscore
19
# Compute device chosen once at import time: CUDA GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def score2percentile(score, ref):
    """Return the percentile rank of *score* within the reference scores *ref*.

    Thin wrapper around scipy.stats.percentileofscore (default 'rank' kind).
    """
    return percentileofscore(ref, score)
26
def drawHeatmap(scores, coords, slide_path=None, wsi_object=None, vis_level = -1, **kwargs):
    """Render an attention heatmap over a whole-slide image.

    Either an already-open *wsi_object* or a *slide_path* to load one from must
    be supplied. Remaining keyword arguments are forwarded unchanged to
    WholeSlideImage.visHeatmap. Returns the rendered heatmap image.
    """
    # Lazily open the slide when the caller only provided a path.
    if wsi_object is None:
        wsi_object = WholeSlideImage(slide_path)
        print(wsi_object.name)

    wsi = wsi_object.getOpenSlide()
    # A negative vis_level means "auto": use the pyramid level closest to a 32x downsample.
    if vis_level < 0:
        vis_level = wsi.get_best_level_for_downsample(32)

    return wsi_object.visHeatmap(scores=scores, coords=coords, vis_level=vis_level, **kwargs)
37
38
def initialize_wsi(wsi_path, seg_mask_path=None, seg_params=None, filter_params=None):
    """Open a whole-slide image, segment its tissue, and save the segmentation.

    NOTE: when seg_params['seg_level'] is negative, it is resolved to an
    automatic level and *seg_params* is mutated in place.
    """
    wsi_object = WholeSlideImage(wsi_path)
    # Negative seg_level means "auto": pick the level nearest a 32x downsample.
    if seg_params['seg_level'] < 0:
        seg_params['seg_level'] = wsi_object.wsi.get_best_level_for_downsample(32)

    wsi_object.segmentTissue(**seg_params, filter_params=filter_params)
    wsi_object.saveSegmentation(seg_mask_path)
    return wsi_object
47
48
def compute_from_patches(wsi_object, clam_pred=None, model=None, feature_extractor=None, batch_size=512,
    attn_save_path=None, ref_scores=None, feat_save_path=None, **wsi_kwargs):
    """Extract patch features (and optionally attention scores) over a WSI region.

    Iterates the region in batches, runs *feature_extractor* on each batch, and
    optionally runs *model* in attention-only mode. Results are streamed to
    HDF5: the first batch creates the file ("w"), later batches append ("a").

    Args:
        wsi_object: WholeSlideImage wrapper for the slide.
        clam_pred: branch index to keep when the attention has multiple branches.
        model: attention model called as model(features, attention_only=True).
        feature_extractor: patch encoder mapping image batches to features.
        batch_size: patches per forward pass.
        attn_save_path: HDF5 path for attention scores + coords (None to skip).
        ref_scores: reference distribution; when given, raw attention scores
            are converted to percentiles of this distribution.
        feat_save_path: HDF5 path for features + coords (None to skip).
        **wsi_kwargs: region spec forwarded to Wsi_Region; must include
            'top_left', 'bot_right' and 'patch_size'.

    Returns:
        (attn_save_path, feat_save_path, wsi_object)
    """
    # Fail fast (KeyError) if the required region kwargs are missing.
    top_left = wsi_kwargs['top_left']
    bot_right = wsi_kwargs['bot_right']
    patch_size = wsi_kwargs['patch_size']

    roi_dataset = Wsi_Region(wsi_object, **wsi_kwargs)
    roi_loader = get_simple_loader(roi_dataset, batch_size=batch_size, num_workers=8)
    print('total number of patches to process: ', len(roi_dataset))
    num_batches = len(roi_loader)
    print('number of batches: ', num_batches)

    # Progress interval is loop-invariant: hoist it, and clamp to >= 1 so the
    # modulo below can never divide by zero.
    print_every = max(1, math.ceil(num_batches * 0.05))

    mode = "w"  # first batch creates the HDF5 file; subsequent batches append
    for idx, (roi, coords) in enumerate(roi_loader):
        roi = roi.to(device)
        coords = coords.numpy()

        with torch.no_grad():
            features = feature_extractor(roi)

            if attn_save_path is not None:
                A = model(features, attention_only=True)

                # CLAM multi-branch attention: keep the branch of the predicted class.
                if A.size(0) > 1:
                    A = A[clam_pred]

                A = A.view(-1, 1).cpu().numpy()

                if ref_scores is not None:
                    # Normalize raw scores to percentiles of the reference distribution.
                    for score_idx in range(len(A)):
                        A[score_idx] = score2percentile(A[score_idx], ref_scores)

                asset_dict = {'attention_scores': A, 'coords': coords}
                save_hdf5(attn_save_path, asset_dict, mode=mode)

        if idx % print_every == 0:
            # Fixed typo: 'procssed' -> 'processed'
            print('processed {} / {}'.format(idx, num_batches))

        if feat_save_path is not None:
            asset_dict = {'features': features.cpu().numpy(), 'coords': coords}
            save_hdf5(feat_save_path, asset_dict, mode=mode)

        mode = "a"
    return attn_save_path, feat_save_path, wsi_object