Diff of /slideflow/segment/data.py [000000] .. [78ef36]

Switch to unified view

a b/slideflow/segment/data.py
1
import torch
2
import slideflow as sf
3
import rasterio
4
import numpy as np
5
import shapely.affinity as sa
6
7
from typing import Tuple, Union, Optional, List, Dict
8
from torchvision import transforms
9
from os.path import join, exists
10
from rich.progress import track
11
from shapely.ops import unary_union
12
from shapely.geometry import Polygon
13
from shapely.ops import unary_union
14
from slideflow.util import path_to_name
15
16
from .utils import topleft_pad_torch
17
18
# -----------------------------------------------------------------------------
19
20
class ThumbMaskDataset(torch.utils.data.Dataset):
21
22
    def __init__(
23
        self,
24
        dataset: "sf.Dataset",
25
        mpp: float,
26
        roi_labels: List[str],
27
        *,
28
        mode: str = 'binary',
29
    ) -> None:
30
        """Dataset that generates thumbnails and ROI masks.
31
32
        Args:
33
            dataset (sf.Dataset): The dataset to use.
34
            mpp (float): The target microns per pixel. The thumbnail will be
35
                scaled to this resolution.
36
            roi_labels (List[str]): The ROI labels to include in the mask.
37
38
        Keyword args:
39
            mode (str, optional): The mode to use for the mask. One of:
40
                'binary', 'multiclass', 'multilabel'. Defaults to 'binary'.
41
42
        """
43
        super().__init__()
44
        self.mpp = mpp
45
        self.mode = mode
46
        self.roi_labels = roi_labels
47
48
        # Subsample dataset to only include slides with ROIs.
49
        self.rois = dataset.rois()
50
        slides = set(map(path_to_name, dataset.slide_paths()))
51
        slides = slides.intersection(set(map(path_to_name, self.rois)))
52
        dataset = dataset.filter({'slide': list(slides)})
53
54
        # Prepare WSI objects (for slides with ROIs).
55
        self.paths = dataset.slide_paths()
56
57
    def __len__(self) -> int:
58
        return len(self.paths)
59
60
    def process(
61
        self,
62
        img: np.ndarray,
63
        mask: np.ndarray
64
    ) -> Tuple[torch.Tensor, torch.Tensor]:
65
        """Process the image/mask and convert to a tensor."""
66
        img = torch.from_numpy(img)
67
        mask = torch.from_numpy(mask)
68
        return img, mask
69
70
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
71
72
        # Load the image and mask.
73
        path = self.paths[index]
74
        wsi = sf.WSI(path, 299, 512, rois=self.rois, roi_filter_method=0.1, verbose=False)
75
        output = get_thumb_and_mask(wsi, self.mpp, self.roi_labels, skip_missing=False)
76
        if output is None:
77
            return None
78
        img = output['image']               # CHW (np.ndarray)
79
        mask = output['mask'].astype(int)   # 1HW (np.ndarray)
80
81
        if self.mode == 'multiclass':
82
            mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None]
83
            mask = mask.max(axis=0)
84
        elif self.mode == 'binary' and mask.ndim == 3:
85
            mask = np.any(mask, axis=0)[None, :, :].astype(int)
86
87
        # Process.
88
        img, mask = self.process(img, mask)
89
90
        return {
91
            'image': img,
92
            'mask': mask
93
        }
94
95
96
class RandomCropDataset(ThumbMaskDataset):
97
98
    def __init__(self, *args, size: int = 1024, **kwargs):
99
        """Dataset that generates thumbnails & ROI masks, with random crops.
100
101
        Thumbnails and masks and randomly cropped and rotated together to
102
        a square size of `size` pixels.
103
104
        Args:
105
            dataset (sf.Dataset): The dataset to use.
106
            mpp (float): The target microns per pixel. The thumbnail will be
107
                scaled to this resolution.
108
            roi_labels (List[str]): The ROI labels to include in the mask.
109
            size (int, optional): The size of the random crop. Defaults to 1024.
110
111
        Keyword Args:
112
            mode (str, optional): The mode to use for the mask. One of:
113
                'binary', 'multiclass', 'multilabel'. Defaults to 'binary'.
114
115
        """
116
        super().__init__(*args, **kwargs)
117
        self.size = size
118
119
    def process(self, img, mask):
120
        """Randomly crop/rotate the image and mask and convert to a tensor."""
121
        return random_crop_and_rotate(img, mask, size=self.size)
122
123
# -----------------------------------------------------------------------------
124
# Buffered datasets
125
126
class BufferedMaskDataset(torch.utils.data.Dataset):
127
128
    def __init__(self, dataset: "sf.Dataset", source: str, *, mode: str = 'binary'):
129
        """Dataset that loads buffered image and mask pairs.
130
131
        Args:
132
            dataset (sf.Dataset): The dataset to use.
133
            source (str): The directory containing the buffered image/mask pairs.
134
135
        Keyword Args:
136
            mode (str, optional): The mode to use for the mask. One of:
137
                'binary', 'multiclass', 'multilabel'. Defaults to 'binary'.
138
139
        """
140
        super().__init__()
141
        if mode not in ['binary', 'multiclass', 'multilabel']:
142
            raise ValueError("Invalid mode: {}. Expected one of: binary, "
143
                             "multiclass, multilabel".format(mode))
144
        self.dataset = dataset
145
        self.mode = mode
146
        self.paths = [
147
            join(source, s + '.pt') for s in dataset.slides()
148
            if exists(join(source, s + '.pt'))
149
        ]
150
151
152
    def __len__(self) -> int:
153
        return len(self.paths)
154
155
    def process(
156
        self,
157
        img: np.ndarray,
158
        mask: np.ndarray
159
    ) -> Tuple[torch.Tensor, torch.Tensor]:
160
        """Process the image/mask and convert to a tensor."""
161
        img = torch.from_numpy(img)
162
        mask = torch.from_numpy(mask)
163
        return img, mask
164
165
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
166
        # Load the image and mask.
167
        output = torch.load(self.paths[index])
168
        img = output['image']               # CHW (np.ndarray)
169
        mask = output['mask'].astype(int)   # 1HW (np.ndarray)
170
171
        if self.mode == 'multiclass':
172
            mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None]
173
            mask = mask.max(axis=0)
174
        elif self.mode == 'binary' and mask.ndim == 3:
175
            mask = np.any(mask, axis=0)[None, :, :].astype(int)
176
177
        # Process.
178
        img, mask = self.process(img, mask)
179
180
        return {
181
            'image': img,
182
            'mask': mask
183
        }
184
185
186
class BufferedRandomCropDataset(BufferedMaskDataset):
187
188
    def __init__(self, *args, size: int = 1024, **kwargs):
189
        """Dataset that loads buffered image/mask pairs and randomly crops.
190
191
        Loaded thumbnails and masks and randomly cropped and rotated together to
192
        a square size of `size` pixels.
193
194
        Args:
195
            dataset (sf.Dataset): The dataset to use.
196
            source (str): The directory containing the buffered image/mask pairs.
197
            size (int, optional): The size of the random crop. Defaults to 1024.
198
199
        Keyword Args:
200
            mode (str, optional): The mode to use for the mask. One of:
201
                'binary', 'multiclass', 'multilabel'. Defaults to 'binary'.
202
203
        """
204
        super().__init__(*args, **kwargs)
205
        self.size = size
206
207
    def process(
208
        self,
209
        img: np.ndarray,
210
        mask: np.ndarray
211
    ) -> Tuple[torch.Tensor, torch.Tensor]:
212
        """Randomly crop/rotate the image and mask and convert to a tensor."""
213
        return random_crop_and_rotate(img, mask, size=self.size)
214
215
# -----------------------------------------------------------------------------
216
217
class TileMaskDataset(torch.utils.data.Dataset):
218
219
    def __init__(
220
        self,
221
        dataset: "sf.Dataset",
222
        tile_px: int,
223
        tile_um: Union[int, str],
224
        *,
225
        roi_labels: Optional[List[str]] = None,
226
        stride_div: int = 2,
227
        crop_margin: int = 0,
228
        filter_method: str = 'otsu',
229
        mode: str = 'binary'
230
    ):
231
        """Dataset that generates tiles and ROI masks from slides.
232
233
        Args:
234
            dataset (sf.Dataset): The dataset to use.
235
            tile_px (int): The size of the tiles (in pixels).
236
            tile_um (Union[int, str]): The size of the tiles (in microns).
237
238
        Keyword args:
239
            stride_div (int, optional): The divisor for the stride.
240
                Defaults to 2.
241
            crop_margin (int, optional): The number of pixels to add to the
242
                tile size before random cropping to the target tile_px size.
243
                Defaults to 0 (no random cropping).
244
            filter_method (str, optional): The method to use for identifying
245
                tiles for training. If 'roi', selects only tiles that intersect
246
                with an ROI. If 'otsu', selects tiles based on an Otsu threshold
247
                of the slide. Defaults to 'roi'.
248
249
        """
250
        super().__init__()
251
252
        rois = dataset.rois()
253
        slides_with_rois = [path_to_name(r) for r in rois]
254
        slides = [s for s in dataset.slide_paths()
255
                  if path_to_name(s) in slides_with_rois]
256
        kw = dict(
257
            tile_px=tile_px + crop_margin,
258
            tile_um=tile_um,
259
            verbose=False,
260
            stride_div=stride_div
261
        )
262
        if roi_labels is None:
263
            roi_labels = []
264
        self.mode = mode
265
        self.roi_labels = roi_labels
266
        self.tile_px = tile_px
267
        self.coords = []
268
        self.all_wsi = dict()
269
        self.all_wsi_with_roi = dict()
270
        self.all_extract_px = dict()
271
        for slide in track(slides, description="Loading slides"):
272
            name = path_to_name(slide)
273
            wsi = sf.WSI(slide, **kw)
274
            try:
275
                wsi_with_rois = sf.WSI(slide, roi_filter_method=0.1, rois=rois, **kw)
276
            except Exception as e:
277
                sf.log.error("Failed to load slide {}: {}".format(slide, e))
278
                raise e
279
280
            # Filter ROIs to only include the specified labels.
281
            if self.roi_labels:
282
                wsi_with_rois.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels]
283
                wsi_with_rois.process_rois()
284
285
            if not len(wsi_with_rois.rois):
286
                continue
287
            if filter_method == 'roi':
288
                wsi_inner = sf.WSI(slide, roi_filter_method=0.9, rois=rois, **kw)
289
                if self.roi_labels:
290
                    wsi_inner.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels]
291
                    wsi_inner.process_rois()
292
                coords = np.argwhere(wsi_with_rois.grid & (~wsi_inner.grid)).tolist()
293
            elif filter_method == 'otsu':
294
                wsi.qc('otsu')
295
                coords = np.argwhere(wsi.grid).tolist()
296
                wsi.remove_qc()
297
            elif filter_method in ['all', 'none', None]:
298
                coords = np.argwhere(wsi_with_rois.grid).tolist()
299
            else:
300
                raise ValueError("Invalid filter method: {}. Expected one of: "
301
                                 "roi, otsu".format(filter_method))
302
            for c in coords:
303
                self.coords.append([name] + c)
304
            self.all_wsi[name] = wsi
305
            self.all_wsi_with_roi[name] = wsi_with_rois
306
            self.all_extract_px[name] = int(wsi.tile_um / wsi.mpp)
307
308
    def __len__(self):
309
        return len(self.coords)
310
311
    def get_scaled_and_intersecting_polys(
312
        self,
313
        polys: "Polygon",
314
        tile: "Polygon",
315
        scale: float,
316
        full_stride: int,
317
        grid_idx: Tuple[int, int]
318
    ):
319
        """Get scaled and intersecting polygons for a given tile."""
320
        gx, gy = grid_idx
321
        A = polys.intersection(tile)
322
323
        # Translate polygons so the intersection origin is at (0, 0)
324
        B = sa.translate(A, -(full_stride*gx), -(full_stride*gy))
325
326
        # Scale to the target tile size
327
        C = sa.scale(B, xfact=scale, yfact=scale, origin=(0, 0))
328
        return C
329
330
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
331
        """Get an image and mask for a given index."""
332
        slide, gx, gy = self.coords[index]
333
        wsi = self.all_wsi[slide]
334
        wsi_with_roi = self.all_wsi_with_roi[slide]
335
        fe = self.all_extract_px[slide]
336
        fs = wsi.full_stride
337
        scale = wsi.tile_px / fe
338
339
        # Get the image.
340
        img = wsi[gx, gy].transpose(2, 0, 1)
341
342
        # Get a polygon for the tile, used for determining overlapping ROIs.
343
        tile = Polygon([
344
            [fs*gx, fs*gy],
345
            [fs*gx, (fs*gy)+fe],
346
            [(fs*gx)+fe, (fs*gy)+fe],
347
            [(fs*gx)+fe, fs*gy]
348
        ])
349
350
        # Compute the mask from ROIs.
351
        if len(wsi_with_roi.rois) == 0:
352
            if self.roi_labels:
353
                mask = np.zeros((len(self.roi_labels), wsi.tile_px, wsi.tile_px), dtype=int)
354
            else:
355
                mask = np.zeros((1, wsi.tile_px, wsi.tile_px), dtype=int)
356
357
        # Handle ROIs with labels (multilabel or multiclass)
358
        elif self.roi_labels:
359
            labeled_masks = []
360
            for i, label in enumerate(self.roi_labels):
361
                wsi_polys = [p.poly for p in wsi_with_roi.rois if p.label == label]
362
                if len(wsi_polys) == 0:
363
                    mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int)
364
                    labeled_masks.append(mask)
365
                else:
366
                    all_polys = unary_union(wsi_polys)
367
                    polys = self.get_scaled_and_intersecting_polys(
368
                        all_polys, tile, scale, fs, (gx, gy)
369
                    )
370
                    if isinstance(polys, Polygon) and polys.is_empty:
371
                        mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int)
372
                    else:
373
                        # Rasterize to an int mask.
374
                        mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(int)
375
                    labeled_masks.append(mask)
376
            mask = np.stack(labeled_masks, axis=0)
377
378
        # Handle ROIs without labels (binary)
379
        else:
380
            # Determine the intersection at the given tile location.
381
            all_polys = unary_union([p.poly for p in wsi_with_roi.rois])
382
            polys = self.get_scaled_and_intersecting_polys(
383
                all_polys, tile, scale, fs, (gx, gy)
384
            )
385
386
            if isinstance(polys, Polygon) and polys.is_empty:
387
                mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int)
388
            else:
389
                # Rasterize to an int mask.
390
                try:
391
                    mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(bool).astype(np.int32)
392
                except ValueError:
393
                    mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int)
394
395
            # Add a dummy channel dimension.
396
            mask = mask[None, :, :]
397
398
        # Process according to the mode.
399
        if self.mode == 'multiclass':
400
            mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None]
401
            mask = mask.max(axis=0)
402
        elif self.mode == 'binary' and mask.ndim == 3:
403
            mask = np.any(mask, axis=0)[None, :, :].astype(int)
404
405
        # Process.
406
        img, mask = self.process(img, mask)
407
408
        return {
409
            'image': img,
410
            'mask': mask
411
        }
412
413
    def process(
414
        self,
415
        img: np.ndarray,
416
        mask: np.ndarray
417
    ) -> Tuple[torch.Tensor, torch.Tensor]:
418
        """Randomly crop/rotate the image and mask and convert to a tensor."""
419
        return random_crop_and_rotate(img, mask, size=self.tile_px)
420
421
# -----------------------------------------------------------------------------
422
423
def random_crop_and_rotate(img, mask, size):
424
    if mask.ndim == 2:
425
        to_squeeze = True
426
        mask = mask[None, :, :]
427
    else:
428
        to_squeeze = False
429
430
    # Convert to tensor.
431
    img = torch.from_numpy(img).permute(1, 2, 0)
432
    mask = torch.from_numpy(mask).permute(1, 2, 0)
433
434
    # Pad to target size.
435
    img = topleft_pad_torch(img, size).permute(2, 0, 1)
436
    mask = topleft_pad_torch(mask, size).permute(2, 0, 1)
437
438
    # Random crop.
439
    i, j, h, w = transforms.RandomCrop.get_params(
440
        img, output_size=(size, size))
441
    img = transforms.functional.crop(img, i, j, h, w)
442
    mask = transforms.functional.crop(mask, i, j, h, w)
443
444
    # Random flip.
445
    if np.random.rand() > 0.5:
446
        img = transforms.functional.hflip(img)
447
        mask = transforms.functional.hflip(mask)
448
    if np.random.rand() > 0.5:
449
        img = transforms.functional.vflip(img)
450
        mask = transforms.functional.vflip(mask)
451
452
    # Random cardinal rotation.
453
    r = np.random.randint(4)
454
    img = transforms.functional.rotate(img, r * 90)
455
    mask = transforms.functional.rotate(mask, r * 90)
456
457
    if to_squeeze:
458
        mask = mask.squeeze(0)
459
460
    return img, mask
461
462
# -----------------------------------------------------------------------------
463
464
def get_thumb_and_mask(
465
    wsi: "sf.WSI",
466
    mpp: float,
467
    roi_labels: Optional[List[str]] = None,
468
    skip_missing: bool = False
469
) -> Dict[str, np.ndarray]:
470
    """Get a thumbnail and segmentation mask for a slide."""
471
472
    if len(wsi.rois) == 0 and skip_missing:
473
        return None
474
475
    # Sanity check.
476
    width = int((wsi.mpp * wsi.dimensions[0]) / mpp)
477
    ds = wsi.dimensions[0] / width
478
    level = wsi.slide.best_level_for_downsample(ds)
479
    level_dim = wsi.slide.level_dimensions[level]
480
    if any([d > 10000 for d in level_dim]):
481
        sf.log.warning("Large thumbnail found ({}) at level={} for {}".format(
482
            level_dim, level, wsi.path
483
        ))
484
485
    # Get the thumbnail.
486
    thumb = wsi.thumb(mpp=mpp).convert('RGB')
487
    img = np.array(thumb).transpose(2, 0, 1)
488
    xfact = thumb.size[1] / wsi.dimensions[1]
489
    yfact = thumb.size[0] / wsi.dimensions[0]
490
491
    if len(wsi.rois) == 0:
492
        if roi_labels:
493
            mask = np.zeros((len(roi_labels), thumb.size[1], thumb.size[0])).astype(bool)
494
        else:
495
            mask = np.zeros((1, thumb.size[1], thumb.size[0])).astype(bool)
496
    elif roi_labels:
497
        labeled_masks = []
498
        for i, label in enumerate(roi_labels):
499
            wsi_polys = [p.poly for p in wsi.rois if p.label == label]
500
            if len(wsi_polys) == 0:
501
                mask = np.zeros((thumb.size[1], thumb.size[0])).astype(bool)
502
                labeled_masks.append(mask)
503
            else:
504
                all_polys = unary_union(wsi_polys)
505
                # Scale ROIs to the thumbnail size.
506
                C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0))
507
                # Rasterize to an int mask.
508
                mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool).astype(np.int32)
509
                labeled_masks.append(mask)
510
        mask = np.stack(labeled_masks, axis=0)
511
512
    else:
513
        all_polys = unary_union([p.poly for p in wsi.rois])
514
        # Scale ROIs to the thumbnail size.
515
        C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0))
516
        # Rasterize to an int mask.
517
        mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool)
518
        # Add a dummy channel dimension.
519
        mask = mask[None, :, :]
520
521
    assert img.shape[1:] == mask.shape[1:], "Image and mask must have the same dimensions."
522
    assert mask.ndim == 3, "Mask must have 3 dimensions (C, H, W)."
523
    assert img.ndim == 3, "Image must have 3 dimensions (C, H, W)."
524
525
    return {
526
        'image': img,
527
        'mask': mask
528
    }