# datasets/wsi_dataset.py
1
from torchvision import transforms
2
import pandas as pd
3
import numpy as np
4
import time
5
import pdb
6
import PIL.Image as Image
7
import h5py
8
from torch.utils.data import Dataset
9
import torch
10
from wsi_core.util_classes import Contour_Checking_fn, isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard
11
12
def default_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the default patch transform: image -> tensor, then per-channel normalization.

    Args:
        mean: per-channel means handed to ``transforms.Normalize``.
        std: per-channel standard deviations handed to ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` that applies ``ToTensor`` followed by ``Normalize``.
    """
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(pipeline)
17
18
def get_contour_check_fn(contour_fn='four_pt_hard', cont=None, ref_patch_size=None, center_shift=None):
    """Instantiate the contour-membership checker selected by ``contour_fn``.

    Args:
        contour_fn: one of 'four_pt_hard', 'four_pt_easy', 'center', 'basic'.
        cont: tissue contour handed to the checker.
        ref_patch_size: patch size handed to the patch-aware checkers
            ('four_pt_hard', 'four_pt_easy', 'center'); Wsi_Region passes its
            level-0 reference patch size here.
        center_shift: center shift used by 'four_pt_hard' only.

    Returns:
        The configured checker instance.

    Raises:
        NotImplementedError: if ``contour_fn`` is not one of the supported names.
    """
    if contour_fn == 'four_pt_hard':
        return isInContourV3_Hard(contour=cont, patch_size=ref_patch_size, center_shift=center_shift)
    if contour_fn == 'four_pt_easy':
        # NOTE(review): the center_shift argument is ignored and fixed at 0.5 for
        # this variant; kept as-is, but confirm this is intended.
        return isInContourV3_Easy(contour=cont, patch_size=ref_patch_size, center_shift=0.5)
    if contour_fn == 'center':
        return isInContourV2(contour=cont, patch_size=ref_patch_size)
    if contour_fn == 'basic':
        return isInContourV1(contour=cont)
    # Name the rejected value so misconfigured callers can see what went wrong.
    raise NotImplementedError(
        "unsupported contour_fn {!r}; expected one of "
        "'four_pt_hard', 'four_pt_easy', 'center', 'basic'".format(contour_fn))
30
31
32
33
class Wsi_Region(Dataset):
    '''
    Dataset of patches tiled over the tissue contours of a whole-slide image (WSI).

    For every tissue contour of ``wsi_object``, candidate patch coordinates are
    produced by ``wsi_object.process_contour`` using the selected contour-check
    function. Items are patches read from the WSI at ``level`` and passed through
    the configured transform.

    args:
        wsi_object: instance of WholeSlideImage wrapper over a WSI
        top_left: tuple of coordinates representing the top left corner of WSI region (Default: None)
        bot_right: tuple of coordinates representing the bottom right corner of WSI region (Default: None)
        level: downsample level at which to process the WSI region
        patch_size: tuple of width, height representing the patch size
        step_size: tuple of w_step, h_step representing the step size
        contour_fn (str):
            contour checking fn to use
            choice of ['four_pt_hard', 'four_pt_easy', 'center', 'basic'] (Default: 'four_pt_hard')
        t: custom torchvision transformation to apply (Default: default_transforms())
        custom_downsample (int): additional downscale factor to apply; when > 1, patches are
            read at ``patch_size * level_downsamples[level] * custom_downsample`` and resized
            back to ``patch_size`` in ``__getitem__``
        use_center_shift (bool): if True, derive the contour-check center shift from the
            overlap between neighbouring patches; otherwise use a center shift of 0

    Each item is ``(patch, coord)``; ``patch`` is the transformed patch with a leading
    dimension of size 1 added via ``unsqueeze(0)``. If no contour yields any coordinate,
    the dataset is empty instead of raising.
    '''
    def __init__(self, wsi_object, top_left=None, bot_right=None, level=0,
                 patch_size=(256, 256), step_size=(256, 256),
                 contour_fn='four_pt_hard',
                 t=None, custom_downsample=1, use_center_shift=False):

        self.custom_downsample = custom_downsample

        # downscale factor of `level` in reference to level 0
        self.ref_downsample = wsi_object.level_downsamples[level]

        if self.custom_downsample > 1:
            # Remember the requested size; __getitem__ resizes the enlarged read back to it.
            self.target_patch_size = patch_size
            patch_size = tuple((np.array(patch_size) * np.array(self.ref_downsample) * custom_downsample).astype(int))
            step_size = tuple((np.array(step_size) * custom_downsample).astype(int))
            # patch size in reference to level 0
            self.ref_size = patch_size
        else:
            step_size = tuple(np.array(step_size).astype(int))
            # patch size in reference to level 0
            self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int))

        self.wsi = wsi_object.wsi
        self.level = level
        self.patch_size = patch_size

        if use_center_shift:
            center_shift = self._center_shift_for_overlap(step_size[0], patch_size[0])
        else:
            center_shift = 0.

        filtered_coords = []
        # iterate through tissue contours for valid patch coordinates
        for cont_idx, contour in enumerate(wsi_object.contours_tissue):
            print('processing {}/{} contours'.format(cont_idx, len(wsi_object.contours_tissue)))
            cont_check_fn = get_contour_check_fn(contour_fn, contour, self.ref_size[0], center_shift)
            coord_results, _ = wsi_object.process_contour(contour, wsi_object.holes_tissue[cont_idx], level, '',
                            patch_size=patch_size[0], step_size=step_size[0], contour_fn=cont_check_fn,
                            use_padding=True, top_left=top_left, bot_right=bot_right)
            if len(coord_results) > 0:
                filtered_coords.append(coord_results['coords'])

        if filtered_coords:
            self.coords = np.vstack(filtered_coords)
        else:
            # np.vstack([]) raises ValueError; an empty region yields an empty dataset instead.
            # Assumes coordinate rows are (x, y) pairs, matching tuple(coord) in __getitem__ — confirm.
            self.coords = np.empty((0, 2), dtype=int)
        print('filtered a total of {} coordinates'.format(len(self.coords)))

        # apply transformation
        self.transforms = default_transforms() if t is None else t

    @staticmethod
    def _center_shift_for_overlap(step, patch):
        '''Map the fractional overlap between neighbouring patches to a contour-check center shift.'''
        overlap = 1 - float(step / patch)
        if overlap < 0.25:
            return 0.375
        if overlap < 0.95:
            # NOTE(review): both the 25-75% and 75-95% overlap bands map to 0.5. Older inline
            # notes suggested 0.625 for 50%/75% overlap and 1.0 for 95% overlap; confirm intent.
            return 0.5
        # overlap >= 0.95 (and the NaN fallthrough, as in the original if/elif chain)
        return 0.625

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        coord = self.coords[idx]
        patch = self.wsi.read_region(tuple(coord), self.level, self.patch_size).convert('RGB')
        if self.custom_downsample > 1:
            patch = patch.resize(self.target_patch_size)
        patch = self.transforms(patch).unsqueeze(0)
        return patch, coord