|
a |
|
b/datasets/wsi_dataset.py |
|
|
1 |
from torchvision import transforms |
|
|
2 |
import pandas as pd |
|
|
3 |
import numpy as np |
|
|
4 |
import time |
|
|
5 |
import pdb |
|
|
6 |
import PIL.Image as Image |
|
|
7 |
import h5py |
|
|
8 |
from torch.utils.data import Dataset |
|
|
9 |
import torch |
|
|
10 |
from wsi_core.util_classes import Contour_Checking_fn, isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard |
|
|
11 |
|
|
|
12 |
def default_transforms(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the default patch preprocessing pipeline.

    Converts a PIL image to a float tensor and normalizes per channel.
    The default mean/std are the standard ImageNet statistics.

    Args:
        mean: per-channel means used by Normalize.
        std: per-channel standard deviations used by Normalize.

    Returns:
        A torchvision ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
|
|
17 |
|
|
|
18 |
def get_contour_check_fn(contour_fn='four_pt_hard', cont=None, ref_patch_size=None, center_shift=None):
    """Return a contour-checking callable used to filter patch coordinates.

    Args:
        contour_fn (str): checking mode, one of
            ['four_pt_hard', 'four_pt_easy', 'center', 'basic'].
        cont: tissue contour the checker validates coordinates against.
        ref_patch_size: patch size (in the contour's reference frame) used by
            the four-point and center checks.
        center_shift: shift factor for the 'four_pt_hard' check; ignored by the
            other modes ('four_pt_easy' always uses 0.5).

    Returns:
        An instance of the matching ``isInContour*`` checker.

    Raises:
        NotImplementedError: if ``contour_fn`` names an unknown checking mode.
    """
    if contour_fn == 'four_pt_hard':
        cont_check_fn = isInContourV3_Hard(contour=cont, patch_size=ref_patch_size, center_shift=center_shift)
    elif contour_fn == 'four_pt_easy':
        cont_check_fn = isInContourV3_Easy(contour=cont, patch_size=ref_patch_size, center_shift=0.5)
    elif contour_fn == 'center':
        cont_check_fn = isInContourV2(contour=cont, patch_size=ref_patch_size)
    elif contour_fn == 'basic':
        cont_check_fn = isInContourV1(contour=cont)
    else:
        # Name the offending mode instead of raising a bare NotImplementedError,
        # so misconfiguration is diagnosable from the traceback alone.
        raise NotImplementedError('unknown contour_fn: {}'.format(contour_fn))
    return cont_check_fn
|
|
30 |
|
|
|
31 |
|
|
|
32 |
|
|
|
33 |
class Wsi_Region(Dataset):
    '''
    Dataset over valid tissue-patch coordinates of a WSI region; yields
    transformed patches read on the fly from the slide.

    args:
        wsi_object: instance of WholeSlideImage wrapper over a WSI
        top_left: tuple of coordinates representing the top left corner of WSI region (Default: None)
        bot_right: tuple of coordinates representing the bot right corner of WSI region (Default: None)
        level: downsample level at which to process the WSI region
        patch_size: tuple of width, height representing the patch size
        step_size: tuple of w_step, h_step representing the step size
        contour_fn (str):
            contour checking fn to use
            choice of ['four_pt_hard', 'four_pt_easy', 'center', 'basic'] (Default: 'four_pt_hard')
        t: custom torchvision transformation to apply
        custom_downsample (int): additional downscale factor to apply
        use_center_shift: for 'four_pt_hard' contour check, how far out to shift the 4 points
    '''
    def __init__(self, wsi_object, top_left=None, bot_right=None, level=0,
                 patch_size = (256, 256), step_size=(256, 256),
                 contour_fn='four_pt_hard',
                 t=None, custom_downsample=1, use_center_shift=False):

        self.custom_downsample = custom_downsample

        # downscale factor in reference to level 0
        self.ref_downsample = wsi_object.level_downsamples[level]
        # patch size in reference to level 0
        # NOTE(review): this assignment is overwritten in both branches of the
        # custom_downsample check below, so it is effectively dead code.
        self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int))

        if self.custom_downsample > 1:
            # Read larger patches from the slide, then resize each one down to
            # target_patch_size in __getitem__.
            self.target_patch_size = patch_size
            patch_size = tuple((np.array(patch_size) * np.array(self.ref_downsample) * custom_downsample).astype(int))
            step_size = tuple((np.array(step_size) * custom_downsample).astype(int))
            self.ref_size = patch_size
        else:
            step_size = tuple((np.array(step_size)).astype(int))
            self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int))

        self.wsi = wsi_object.wsi
        self.level = level
        self.patch_size = patch_size

        if not use_center_shift:
            center_shift = 0.
        else:
            # Pick the four-point check's shift from the patch overlap fraction
            # (overlap = 1 means fully overlapping steps, 0 means no overlap).
            overlap = 1 - float(step_size[0] / patch_size[0])
            if overlap < 0.25:
                center_shift = 0.375
            elif overlap >= 0.25 and overlap < 0.75:
                center_shift = 0.5
            elif overlap >=0.75 and overlap < 0.95:
                # NOTE(review): identical to the previous branch (both 0.5) —
                # possibly meant to differ; compare the commented values below.
                center_shift = 0.5
            else:
                center_shift = 0.625
            #center_shift = 0.375 # 25% overlap
            #center_shift = 0.625 #50%, 75% overlap
            #center_shift = 1.0 #95% overlap

        filtered_coords = []
        #iterate through tissue contours for valid patch coordinates
        for cont_idx, contour in enumerate(wsi_object.contours_tissue):
            print('processing {}/{} contours'.format(cont_idx, len(wsi_object.contours_tissue)))
            cont_check_fn = get_contour_check_fn(contour_fn, contour, self.ref_size[0], center_shift)
            coord_results, _ = wsi_object.process_contour(contour, wsi_object.holes_tissue[cont_idx], level, '',
                                            patch_size = patch_size[0], step_size = step_size[0], contour_fn=cont_check_fn,
                                            use_padding=True, top_left = top_left, bot_right = bot_right)
            if len(coord_results) > 0:
                filtered_coords.append(coord_results['coords'])

        # NOTE(review): np.vstack raises ValueError when filtered_coords is
        # empty, i.e. when no contour yields any valid coordinates — confirm
        # whether callers guarantee at least one non-empty contour.
        coords=np.vstack(filtered_coords)

        self.coords = coords
        print('filtered a total of {} coordinates'.format(len(self.coords)))

        # apply transformation
        if t is None:
            self.transforms = default_transforms()
        else:
            self.transforms = t

    def __len__(self):
        # Number of valid patch coordinates collected at init time.
        return len(self.coords)

    def __getitem__(self, idx):
        # Read the patch lazily from the slide at the stored coordinate.
        coord = self.coords[idx]
        # read_region presumably yields an RGBA PIL image — converted to RGB
        # here so the transform pipeline sees 3 channels.
        patch = self.wsi.read_region(tuple(coord), self.level, self.patch_size).convert('RGB')
        if self.custom_downsample > 1:
            # Downscale the oversized read back to the requested target size.
            patch = patch.resize(self.target_patch_size)
        # unsqueeze adds a leading batch dimension of 1 to the transformed patch
        patch = self.transforms(patch).unsqueeze(0)
        return patch, coord