|
a |
|
b/slideflow/segment/data.py |
|
|
1 |
import torch |
|
|
2 |
import slideflow as sf |
|
|
3 |
import rasterio |
|
|
4 |
import numpy as np |
|
|
5 |
import shapely.affinity as sa |
|
|
6 |
|
|
|
7 |
from typing import Tuple, Union, Optional, List, Dict |
|
|
8 |
from torchvision import transforms |
|
|
9 |
from os.path import join, exists |
|
|
10 |
from rich.progress import track |
|
|
11 |
from shapely.ops import unary_union |
|
|
12 |
from shapely.geometry import Polygon |
|
|
13 |
from shapely.ops import unary_union |
|
|
14 |
from slideflow.util import path_to_name |
|
|
15 |
|
|
|
16 |
from .utils import topleft_pad_torch |
|
|
17 |
|
|
|
18 |
# ----------------------------------------------------------------------------- |
|
|
19 |
|
|
|
20 |
class ThumbMaskDataset(torch.utils.data.Dataset): |
|
|
21 |
|
|
|
22 |
def __init__( |
|
|
23 |
self, |
|
|
24 |
dataset: "sf.Dataset", |
|
|
25 |
mpp: float, |
|
|
26 |
roi_labels: List[str], |
|
|
27 |
*, |
|
|
28 |
mode: str = 'binary', |
|
|
29 |
) -> None: |
|
|
30 |
"""Dataset that generates thumbnails and ROI masks. |
|
|
31 |
|
|
|
32 |
Args: |
|
|
33 |
dataset (sf.Dataset): The dataset to use. |
|
|
34 |
mpp (float): The target microns per pixel. The thumbnail will be |
|
|
35 |
scaled to this resolution. |
|
|
36 |
roi_labels (List[str]): The ROI labels to include in the mask. |
|
|
37 |
|
|
|
38 |
Keyword args: |
|
|
39 |
mode (str, optional): The mode to use for the mask. One of: |
|
|
40 |
'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. |
|
|
41 |
|
|
|
42 |
""" |
|
|
43 |
super().__init__() |
|
|
44 |
self.mpp = mpp |
|
|
45 |
self.mode = mode |
|
|
46 |
self.roi_labels = roi_labels |
|
|
47 |
|
|
|
48 |
# Subsample dataset to only include slides with ROIs. |
|
|
49 |
self.rois = dataset.rois() |
|
|
50 |
slides = set(map(path_to_name, dataset.slide_paths())) |
|
|
51 |
slides = slides.intersection(set(map(path_to_name, self.rois))) |
|
|
52 |
dataset = dataset.filter({'slide': list(slides)}) |
|
|
53 |
|
|
|
54 |
# Prepare WSI objects (for slides with ROIs). |
|
|
55 |
self.paths = dataset.slide_paths() |
|
|
56 |
|
|
|
57 |
def __len__(self) -> int: |
|
|
58 |
return len(self.paths) |
|
|
59 |
|
|
|
60 |
def process( |
|
|
61 |
self, |
|
|
62 |
img: np.ndarray, |
|
|
63 |
mask: np.ndarray |
|
|
64 |
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
65 |
"""Process the image/mask and convert to a tensor.""" |
|
|
66 |
img = torch.from_numpy(img) |
|
|
67 |
mask = torch.from_numpy(mask) |
|
|
68 |
return img, mask |
|
|
69 |
|
|
|
70 |
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
71 |
|
|
|
72 |
# Load the image and mask. |
|
|
73 |
path = self.paths[index] |
|
|
74 |
wsi = sf.WSI(path, 299, 512, rois=self.rois, roi_filter_method=0.1, verbose=False) |
|
|
75 |
output = get_thumb_and_mask(wsi, self.mpp, self.roi_labels, skip_missing=False) |
|
|
76 |
if output is None: |
|
|
77 |
return None |
|
|
78 |
img = output['image'] # CHW (np.ndarray) |
|
|
79 |
mask = output['mask'].astype(int) # 1HW (np.ndarray) |
|
|
80 |
|
|
|
81 |
if self.mode == 'multiclass': |
|
|
82 |
mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] |
|
|
83 |
mask = mask.max(axis=0) |
|
|
84 |
elif self.mode == 'binary' and mask.ndim == 3: |
|
|
85 |
mask = np.any(mask, axis=0)[None, :, :].astype(int) |
|
|
86 |
|
|
|
87 |
# Process. |
|
|
88 |
img, mask = self.process(img, mask) |
|
|
89 |
|
|
|
90 |
return { |
|
|
91 |
'image': img, |
|
|
92 |
'mask': mask |
|
|
93 |
} |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
class RandomCropDataset(ThumbMaskDataset): |
|
|
97 |
|
|
|
98 |
def __init__(self, *args, size: int = 1024, **kwargs): |
|
|
99 |
"""Dataset that generates thumbnails & ROI masks, with random crops. |
|
|
100 |
|
|
|
101 |
Thumbnails and masks and randomly cropped and rotated together to |
|
|
102 |
a square size of `size` pixels. |
|
|
103 |
|
|
|
104 |
Args: |
|
|
105 |
dataset (sf.Dataset): The dataset to use. |
|
|
106 |
mpp (float): The target microns per pixel. The thumbnail will be |
|
|
107 |
scaled to this resolution. |
|
|
108 |
roi_labels (List[str]): The ROI labels to include in the mask. |
|
|
109 |
size (int, optional): The size of the random crop. Defaults to 1024. |
|
|
110 |
|
|
|
111 |
Keyword Args: |
|
|
112 |
mode (str, optional): The mode to use for the mask. One of: |
|
|
113 |
'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. |
|
|
114 |
|
|
|
115 |
""" |
|
|
116 |
super().__init__(*args, **kwargs) |
|
|
117 |
self.size = size |
|
|
118 |
|
|
|
119 |
def process(self, img, mask): |
|
|
120 |
"""Randomly crop/rotate the image and mask and convert to a tensor.""" |
|
|
121 |
return random_crop_and_rotate(img, mask, size=self.size) |
|
|
122 |
|
|
|
123 |
# ----------------------------------------------------------------------------- |
|
|
124 |
# Buffered datasets |
|
|
125 |
|
|
|
126 |
class BufferedMaskDataset(torch.utils.data.Dataset): |
|
|
127 |
|
|
|
128 |
def __init__(self, dataset: "sf.Dataset", source: str, *, mode: str = 'binary'): |
|
|
129 |
"""Dataset that loads buffered image and mask pairs. |
|
|
130 |
|
|
|
131 |
Args: |
|
|
132 |
dataset (sf.Dataset): The dataset to use. |
|
|
133 |
source (str): The directory containing the buffered image/mask pairs. |
|
|
134 |
|
|
|
135 |
Keyword Args: |
|
|
136 |
mode (str, optional): The mode to use for the mask. One of: |
|
|
137 |
'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. |
|
|
138 |
|
|
|
139 |
""" |
|
|
140 |
super().__init__() |
|
|
141 |
if mode not in ['binary', 'multiclass', 'multilabel']: |
|
|
142 |
raise ValueError("Invalid mode: {}. Expected one of: binary, " |
|
|
143 |
"multiclass, multilabel".format(mode)) |
|
|
144 |
self.dataset = dataset |
|
|
145 |
self.mode = mode |
|
|
146 |
self.paths = [ |
|
|
147 |
join(source, s + '.pt') for s in dataset.slides() |
|
|
148 |
if exists(join(source, s + '.pt')) |
|
|
149 |
] |
|
|
150 |
|
|
|
151 |
|
|
|
152 |
def __len__(self) -> int: |
|
|
153 |
return len(self.paths) |
|
|
154 |
|
|
|
155 |
def process( |
|
|
156 |
self, |
|
|
157 |
img: np.ndarray, |
|
|
158 |
mask: np.ndarray |
|
|
159 |
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
160 |
"""Process the image/mask and convert to a tensor.""" |
|
|
161 |
img = torch.from_numpy(img) |
|
|
162 |
mask = torch.from_numpy(mask) |
|
|
163 |
return img, mask |
|
|
164 |
|
|
|
165 |
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
166 |
# Load the image and mask. |
|
|
167 |
output = torch.load(self.paths[index]) |
|
|
168 |
img = output['image'] # CHW (np.ndarray) |
|
|
169 |
mask = output['mask'].astype(int) # 1HW (np.ndarray) |
|
|
170 |
|
|
|
171 |
if self.mode == 'multiclass': |
|
|
172 |
mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] |
|
|
173 |
mask = mask.max(axis=0) |
|
|
174 |
elif self.mode == 'binary' and mask.ndim == 3: |
|
|
175 |
mask = np.any(mask, axis=0)[None, :, :].astype(int) |
|
|
176 |
|
|
|
177 |
# Process. |
|
|
178 |
img, mask = self.process(img, mask) |
|
|
179 |
|
|
|
180 |
return { |
|
|
181 |
'image': img, |
|
|
182 |
'mask': mask |
|
|
183 |
} |
|
|
184 |
|
|
|
185 |
|
|
|
186 |
class BufferedRandomCropDataset(BufferedMaskDataset): |
|
|
187 |
|
|
|
188 |
def __init__(self, *args, size: int = 1024, **kwargs): |
|
|
189 |
"""Dataset that loads buffered image/mask pairs and randomly crops. |
|
|
190 |
|
|
|
191 |
Loaded thumbnails and masks and randomly cropped and rotated together to |
|
|
192 |
a square size of `size` pixels. |
|
|
193 |
|
|
|
194 |
Args: |
|
|
195 |
dataset (sf.Dataset): The dataset to use. |
|
|
196 |
source (str): The directory containing the buffered image/mask pairs. |
|
|
197 |
size (int, optional): The size of the random crop. Defaults to 1024. |
|
|
198 |
|
|
|
199 |
Keyword Args: |
|
|
200 |
mode (str, optional): The mode to use for the mask. One of: |
|
|
201 |
'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. |
|
|
202 |
|
|
|
203 |
""" |
|
|
204 |
super().__init__(*args, **kwargs) |
|
|
205 |
self.size = size |
|
|
206 |
|
|
|
207 |
def process( |
|
|
208 |
self, |
|
|
209 |
img: np.ndarray, |
|
|
210 |
mask: np.ndarray |
|
|
211 |
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
212 |
"""Randomly crop/rotate the image and mask and convert to a tensor.""" |
|
|
213 |
return random_crop_and_rotate(img, mask, size=self.size) |
|
|
214 |
|
|
|
215 |
# ----------------------------------------------------------------------------- |
|
|
216 |
|
|
|
217 |
class TileMaskDataset(torch.utils.data.Dataset): |
|
|
218 |
|
|
|
219 |
def __init__( |
|
|
220 |
self, |
|
|
221 |
dataset: "sf.Dataset", |
|
|
222 |
tile_px: int, |
|
|
223 |
tile_um: Union[int, str], |
|
|
224 |
*, |
|
|
225 |
roi_labels: Optional[List[str]] = None, |
|
|
226 |
stride_div: int = 2, |
|
|
227 |
crop_margin: int = 0, |
|
|
228 |
filter_method: str = 'otsu', |
|
|
229 |
mode: str = 'binary' |
|
|
230 |
): |
|
|
231 |
"""Dataset that generates tiles and ROI masks from slides. |
|
|
232 |
|
|
|
233 |
Args: |
|
|
234 |
dataset (sf.Dataset): The dataset to use. |
|
|
235 |
tile_px (int): The size of the tiles (in pixels). |
|
|
236 |
tile_um (Union[int, str]): The size of the tiles (in microns). |
|
|
237 |
|
|
|
238 |
Keyword args: |
|
|
239 |
stride_div (int, optional): The divisor for the stride. |
|
|
240 |
Defaults to 2. |
|
|
241 |
crop_margin (int, optional): The number of pixels to add to the |
|
|
242 |
tile size before random cropping to the target tile_px size. |
|
|
243 |
Defaults to 0 (no random cropping). |
|
|
244 |
filter_method (str, optional): The method to use for identifying |
|
|
245 |
tiles for training. If 'roi', selects only tiles that intersect |
|
|
246 |
with an ROI. If 'otsu', selects tiles based on an Otsu threshold |
|
|
247 |
of the slide. Defaults to 'roi'. |
|
|
248 |
|
|
|
249 |
""" |
|
|
250 |
super().__init__() |
|
|
251 |
|
|
|
252 |
rois = dataset.rois() |
|
|
253 |
slides_with_rois = [path_to_name(r) for r in rois] |
|
|
254 |
slides = [s for s in dataset.slide_paths() |
|
|
255 |
if path_to_name(s) in slides_with_rois] |
|
|
256 |
kw = dict( |
|
|
257 |
tile_px=tile_px + crop_margin, |
|
|
258 |
tile_um=tile_um, |
|
|
259 |
verbose=False, |
|
|
260 |
stride_div=stride_div |
|
|
261 |
) |
|
|
262 |
if roi_labels is None: |
|
|
263 |
roi_labels = [] |
|
|
264 |
self.mode = mode |
|
|
265 |
self.roi_labels = roi_labels |
|
|
266 |
self.tile_px = tile_px |
|
|
267 |
self.coords = [] |
|
|
268 |
self.all_wsi = dict() |
|
|
269 |
self.all_wsi_with_roi = dict() |
|
|
270 |
self.all_extract_px = dict() |
|
|
271 |
for slide in track(slides, description="Loading slides"): |
|
|
272 |
name = path_to_name(slide) |
|
|
273 |
wsi = sf.WSI(slide, **kw) |
|
|
274 |
try: |
|
|
275 |
wsi_with_rois = sf.WSI(slide, roi_filter_method=0.1, rois=rois, **kw) |
|
|
276 |
except Exception as e: |
|
|
277 |
sf.log.error("Failed to load slide {}: {}".format(slide, e)) |
|
|
278 |
raise e |
|
|
279 |
|
|
|
280 |
# Filter ROIs to only include the specified labels. |
|
|
281 |
if self.roi_labels: |
|
|
282 |
wsi_with_rois.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels] |
|
|
283 |
wsi_with_rois.process_rois() |
|
|
284 |
|
|
|
285 |
if not len(wsi_with_rois.rois): |
|
|
286 |
continue |
|
|
287 |
if filter_method == 'roi': |
|
|
288 |
wsi_inner = sf.WSI(slide, roi_filter_method=0.9, rois=rois, **kw) |
|
|
289 |
if self.roi_labels: |
|
|
290 |
wsi_inner.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels] |
|
|
291 |
wsi_inner.process_rois() |
|
|
292 |
coords = np.argwhere(wsi_with_rois.grid & (~wsi_inner.grid)).tolist() |
|
|
293 |
elif filter_method == 'otsu': |
|
|
294 |
wsi.qc('otsu') |
|
|
295 |
coords = np.argwhere(wsi.grid).tolist() |
|
|
296 |
wsi.remove_qc() |
|
|
297 |
elif filter_method in ['all', 'none', None]: |
|
|
298 |
coords = np.argwhere(wsi_with_rois.grid).tolist() |
|
|
299 |
else: |
|
|
300 |
raise ValueError("Invalid filter method: {}. Expected one of: " |
|
|
301 |
"roi, otsu".format(filter_method)) |
|
|
302 |
for c in coords: |
|
|
303 |
self.coords.append([name] + c) |
|
|
304 |
self.all_wsi[name] = wsi |
|
|
305 |
self.all_wsi_with_roi[name] = wsi_with_rois |
|
|
306 |
self.all_extract_px[name] = int(wsi.tile_um / wsi.mpp) |
|
|
307 |
|
|
|
308 |
def __len__(self): |
|
|
309 |
return len(self.coords) |
|
|
310 |
|
|
|
311 |
def get_scaled_and_intersecting_polys( |
|
|
312 |
self, |
|
|
313 |
polys: "Polygon", |
|
|
314 |
tile: "Polygon", |
|
|
315 |
scale: float, |
|
|
316 |
full_stride: int, |
|
|
317 |
grid_idx: Tuple[int, int] |
|
|
318 |
): |
|
|
319 |
"""Get scaled and intersecting polygons for a given tile.""" |
|
|
320 |
gx, gy = grid_idx |
|
|
321 |
A = polys.intersection(tile) |
|
|
322 |
|
|
|
323 |
# Translate polygons so the intersection origin is at (0, 0) |
|
|
324 |
B = sa.translate(A, -(full_stride*gx), -(full_stride*gy)) |
|
|
325 |
|
|
|
326 |
# Scale to the target tile size |
|
|
327 |
C = sa.scale(B, xfact=scale, yfact=scale, origin=(0, 0)) |
|
|
328 |
return C |
|
|
329 |
|
|
|
330 |
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
331 |
"""Get an image and mask for a given index.""" |
|
|
332 |
slide, gx, gy = self.coords[index] |
|
|
333 |
wsi = self.all_wsi[slide] |
|
|
334 |
wsi_with_roi = self.all_wsi_with_roi[slide] |
|
|
335 |
fe = self.all_extract_px[slide] |
|
|
336 |
fs = wsi.full_stride |
|
|
337 |
scale = wsi.tile_px / fe |
|
|
338 |
|
|
|
339 |
# Get the image. |
|
|
340 |
img = wsi[gx, gy].transpose(2, 0, 1) |
|
|
341 |
|
|
|
342 |
# Get a polygon for the tile, used for determining overlapping ROIs. |
|
|
343 |
tile = Polygon([ |
|
|
344 |
[fs*gx, fs*gy], |
|
|
345 |
[fs*gx, (fs*gy)+fe], |
|
|
346 |
[(fs*gx)+fe, (fs*gy)+fe], |
|
|
347 |
[(fs*gx)+fe, fs*gy] |
|
|
348 |
]) |
|
|
349 |
|
|
|
350 |
# Compute the mask from ROIs. |
|
|
351 |
if len(wsi_with_roi.rois) == 0: |
|
|
352 |
if self.roi_labels: |
|
|
353 |
mask = np.zeros((len(self.roi_labels), wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
354 |
else: |
|
|
355 |
mask = np.zeros((1, wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
356 |
|
|
|
357 |
# Handle ROIs with labels (multilabel or multiclass) |
|
|
358 |
elif self.roi_labels: |
|
|
359 |
labeled_masks = [] |
|
|
360 |
for i, label in enumerate(self.roi_labels): |
|
|
361 |
wsi_polys = [p.poly for p in wsi_with_roi.rois if p.label == label] |
|
|
362 |
if len(wsi_polys) == 0: |
|
|
363 |
mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
364 |
labeled_masks.append(mask) |
|
|
365 |
else: |
|
|
366 |
all_polys = unary_union(wsi_polys) |
|
|
367 |
polys = self.get_scaled_and_intersecting_polys( |
|
|
368 |
all_polys, tile, scale, fs, (gx, gy) |
|
|
369 |
) |
|
|
370 |
if isinstance(polys, Polygon) and polys.is_empty: |
|
|
371 |
mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
372 |
else: |
|
|
373 |
# Rasterize to an int mask. |
|
|
374 |
mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(int) |
|
|
375 |
labeled_masks.append(mask) |
|
|
376 |
mask = np.stack(labeled_masks, axis=0) |
|
|
377 |
|
|
|
378 |
# Handle ROIs without labels (binary) |
|
|
379 |
else: |
|
|
380 |
# Determine the intersection at the given tile location. |
|
|
381 |
all_polys = unary_union([p.poly for p in wsi_with_roi.rois]) |
|
|
382 |
polys = self.get_scaled_and_intersecting_polys( |
|
|
383 |
all_polys, tile, scale, fs, (gx, gy) |
|
|
384 |
) |
|
|
385 |
|
|
|
386 |
if isinstance(polys, Polygon) and polys.is_empty: |
|
|
387 |
mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
388 |
else: |
|
|
389 |
# Rasterize to an int mask. |
|
|
390 |
try: |
|
|
391 |
mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(bool).astype(np.int32) |
|
|
392 |
except ValueError: |
|
|
393 |
mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) |
|
|
394 |
|
|
|
395 |
# Add a dummy channel dimension. |
|
|
396 |
mask = mask[None, :, :] |
|
|
397 |
|
|
|
398 |
# Process according to the mode. |
|
|
399 |
if self.mode == 'multiclass': |
|
|
400 |
mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] |
|
|
401 |
mask = mask.max(axis=0) |
|
|
402 |
elif self.mode == 'binary' and mask.ndim == 3: |
|
|
403 |
mask = np.any(mask, axis=0)[None, :, :].astype(int) |
|
|
404 |
|
|
|
405 |
# Process. |
|
|
406 |
img, mask = self.process(img, mask) |
|
|
407 |
|
|
|
408 |
return { |
|
|
409 |
'image': img, |
|
|
410 |
'mask': mask |
|
|
411 |
} |
|
|
412 |
|
|
|
413 |
def process( |
|
|
414 |
self, |
|
|
415 |
img: np.ndarray, |
|
|
416 |
mask: np.ndarray |
|
|
417 |
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
418 |
"""Randomly crop/rotate the image and mask and convert to a tensor.""" |
|
|
419 |
return random_crop_and_rotate(img, mask, size=self.tile_px) |
|
|
420 |
|
|
|
421 |
# ----------------------------------------------------------------------------- |
|
|
422 |
|
|
|
423 |
def random_crop_and_rotate(img, mask, size): |
|
|
424 |
if mask.ndim == 2: |
|
|
425 |
to_squeeze = True |
|
|
426 |
mask = mask[None, :, :] |
|
|
427 |
else: |
|
|
428 |
to_squeeze = False |
|
|
429 |
|
|
|
430 |
# Convert to tensor. |
|
|
431 |
img = torch.from_numpy(img).permute(1, 2, 0) |
|
|
432 |
mask = torch.from_numpy(mask).permute(1, 2, 0) |
|
|
433 |
|
|
|
434 |
# Pad to target size. |
|
|
435 |
img = topleft_pad_torch(img, size).permute(2, 0, 1) |
|
|
436 |
mask = topleft_pad_torch(mask, size).permute(2, 0, 1) |
|
|
437 |
|
|
|
438 |
# Random crop. |
|
|
439 |
i, j, h, w = transforms.RandomCrop.get_params( |
|
|
440 |
img, output_size=(size, size)) |
|
|
441 |
img = transforms.functional.crop(img, i, j, h, w) |
|
|
442 |
mask = transforms.functional.crop(mask, i, j, h, w) |
|
|
443 |
|
|
|
444 |
# Random flip. |
|
|
445 |
if np.random.rand() > 0.5: |
|
|
446 |
img = transforms.functional.hflip(img) |
|
|
447 |
mask = transforms.functional.hflip(mask) |
|
|
448 |
if np.random.rand() > 0.5: |
|
|
449 |
img = transforms.functional.vflip(img) |
|
|
450 |
mask = transforms.functional.vflip(mask) |
|
|
451 |
|
|
|
452 |
# Random cardinal rotation. |
|
|
453 |
r = np.random.randint(4) |
|
|
454 |
img = transforms.functional.rotate(img, r * 90) |
|
|
455 |
mask = transforms.functional.rotate(mask, r * 90) |
|
|
456 |
|
|
|
457 |
if to_squeeze: |
|
|
458 |
mask = mask.squeeze(0) |
|
|
459 |
|
|
|
460 |
return img, mask |
|
|
461 |
|
|
|
462 |
# ----------------------------------------------------------------------------- |
|
|
463 |
|
|
|
464 |
def get_thumb_and_mask( |
|
|
465 |
wsi: "sf.WSI", |
|
|
466 |
mpp: float, |
|
|
467 |
roi_labels: Optional[List[str]] = None, |
|
|
468 |
skip_missing: bool = False |
|
|
469 |
) -> Dict[str, np.ndarray]: |
|
|
470 |
"""Get a thumbnail and segmentation mask for a slide.""" |
|
|
471 |
|
|
|
472 |
if len(wsi.rois) == 0 and skip_missing: |
|
|
473 |
return None |
|
|
474 |
|
|
|
475 |
# Sanity check. |
|
|
476 |
width = int((wsi.mpp * wsi.dimensions[0]) / mpp) |
|
|
477 |
ds = wsi.dimensions[0] / width |
|
|
478 |
level = wsi.slide.best_level_for_downsample(ds) |
|
|
479 |
level_dim = wsi.slide.level_dimensions[level] |
|
|
480 |
if any([d > 10000 for d in level_dim]): |
|
|
481 |
sf.log.warning("Large thumbnail found ({}) at level={} for {}".format( |
|
|
482 |
level_dim, level, wsi.path |
|
|
483 |
)) |
|
|
484 |
|
|
|
485 |
# Get the thumbnail. |
|
|
486 |
thumb = wsi.thumb(mpp=mpp).convert('RGB') |
|
|
487 |
img = np.array(thumb).transpose(2, 0, 1) |
|
|
488 |
xfact = thumb.size[1] / wsi.dimensions[1] |
|
|
489 |
yfact = thumb.size[0] / wsi.dimensions[0] |
|
|
490 |
|
|
|
491 |
if len(wsi.rois) == 0: |
|
|
492 |
if roi_labels: |
|
|
493 |
mask = np.zeros((len(roi_labels), thumb.size[1], thumb.size[0])).astype(bool) |
|
|
494 |
else: |
|
|
495 |
mask = np.zeros((1, thumb.size[1], thumb.size[0])).astype(bool) |
|
|
496 |
elif roi_labels: |
|
|
497 |
labeled_masks = [] |
|
|
498 |
for i, label in enumerate(roi_labels): |
|
|
499 |
wsi_polys = [p.poly for p in wsi.rois if p.label == label] |
|
|
500 |
if len(wsi_polys) == 0: |
|
|
501 |
mask = np.zeros((thumb.size[1], thumb.size[0])).astype(bool) |
|
|
502 |
labeled_masks.append(mask) |
|
|
503 |
else: |
|
|
504 |
all_polys = unary_union(wsi_polys) |
|
|
505 |
# Scale ROIs to the thumbnail size. |
|
|
506 |
C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0)) |
|
|
507 |
# Rasterize to an int mask. |
|
|
508 |
mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool).astype(np.int32) |
|
|
509 |
labeled_masks.append(mask) |
|
|
510 |
mask = np.stack(labeled_masks, axis=0) |
|
|
511 |
|
|
|
512 |
else: |
|
|
513 |
all_polys = unary_union([p.poly for p in wsi.rois]) |
|
|
514 |
# Scale ROIs to the thumbnail size. |
|
|
515 |
C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0)) |
|
|
516 |
# Rasterize to an int mask. |
|
|
517 |
mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool) |
|
|
518 |
# Add a dummy channel dimension. |
|
|
519 |
mask = mask[None, :, :] |
|
|
520 |
|
|
|
521 |
assert img.shape[1:] == mask.shape[1:], "Image and mask must have the same dimensions." |
|
|
522 |
assert mask.ndim == 3, "Mask must have 3 dimensions (C, H, W)." |
|
|
523 |
assert img.ndim == 3, "Image must have 3 dimensions (C, H, W)." |
|
|
524 |
|
|
|
525 |
return { |
|
|
526 |
'image': img, |
|
|
527 |
'mask': mask |
|
|
528 |
} |