--- a +++ b/slideflow/segment/data.py @@ -0,0 +1,528 @@ +import torch +import slideflow as sf +import rasterio +import numpy as np +import shapely.affinity as sa + +from typing import Tuple, Union, Optional, List, Dict +from torchvision import transforms +from os.path import join, exists +from rich.progress import track +from shapely.ops import unary_union +from shapely.geometry import Polygon +from shapely.ops import unary_union +from slideflow.util import path_to_name + +from .utils import topleft_pad_torch + +# ----------------------------------------------------------------------------- + +class ThumbMaskDataset(torch.utils.data.Dataset): + + def __init__( + self, + dataset: "sf.Dataset", + mpp: float, + roi_labels: List[str], + *, + mode: str = 'binary', + ) -> None: + """Dataset that generates thumbnails and ROI masks. + + Args: + dataset (sf.Dataset): The dataset to use. + mpp (float): The target microns per pixel. The thumbnail will be + scaled to this resolution. + roi_labels (List[str]): The ROI labels to include in the mask. + + Keyword args: + mode (str, optional): The mode to use for the mask. One of: + 'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. + + """ + super().__init__() + self.mpp = mpp + self.mode = mode + self.roi_labels = roi_labels + + # Subsample dataset to only include slides with ROIs. + self.rois = dataset.rois() + slides = set(map(path_to_name, dataset.slide_paths())) + slides = slides.intersection(set(map(path_to_name, self.rois))) + dataset = dataset.filter({'slide': list(slides)}) + + # Prepare WSI objects (for slides with ROIs). + self.paths = dataset.slide_paths() + + def __len__(self) -> int: + return len(self.paths) + + def process( + self, + img: np.ndarray, + mask: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Process the image/mask and convert to a tensor.""" + img = torch.from_numpy(img) + mask = torch.from_numpy(mask) + return img, mask + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + + # Load the image and mask. + path = self.paths[index] + wsi = sf.WSI(path, 299, 512, rois=self.rois, roi_filter_method=0.1, verbose=False) + output = get_thumb_and_mask(wsi, self.mpp, self.roi_labels, skip_missing=False) + if output is None: + return None + img = output['image'] # CHW (np.ndarray) + mask = output['mask'].astype(int) # 1HW (np.ndarray) + + if self.mode == 'multiclass': + mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] + mask = mask.max(axis=0) + elif self.mode == 'binary' and mask.ndim == 3: + mask = np.any(mask, axis=0)[None, :, :].astype(int) + + # Process. + img, mask = self.process(img, mask) + + return { + 'image': img, + 'mask': mask + } + + +class RandomCropDataset(ThumbMaskDataset): + + def __init__(self, *args, size: int = 1024, **kwargs): + """Dataset that generates thumbnails & ROI masks, with random crops. + + Thumbnails and masks and randomly cropped and rotated together to + a square size of `size` pixels. + + Args: + dataset (sf.Dataset): The dataset to use. + mpp (float): The target microns per pixel. The thumbnail will be + scaled to this resolution. + roi_labels (List[str]): The ROI labels to include in the mask. + size (int, optional): The size of the random crop. Defaults to 1024. + + Keyword Args: + mode (str, optional): The mode to use for the mask. One of: + 'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. + + """ + super().__init__(*args, **kwargs) + self.size = size + + def process(self, img, mask): + """Randomly crop/rotate the image and mask and convert to a tensor.""" + return random_crop_and_rotate(img, mask, size=self.size) + +# ----------------------------------------------------------------------------- +# Buffered datasets + +class BufferedMaskDataset(torch.utils.data.Dataset): + + def __init__(self, dataset: "sf.Dataset", source: str, *, mode: str = 'binary'): + """Dataset that loads buffered image and mask pairs. + + Args: + dataset (sf.Dataset): The dataset to use. + source (str): The directory containing the buffered image/mask pairs. + + Keyword Args: + mode (str, optional): The mode to use for the mask. One of: + 'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. + + """ + super().__init__() + if mode not in ['binary', 'multiclass', 'multilabel']: + raise ValueError("Invalid mode: {}. Expected one of: binary, " + "multiclass, multilabel".format(mode)) + self.dataset = dataset + self.mode = mode + self.paths = [ + join(source, s + '.pt') for s in dataset.slides() + if exists(join(source, s + '.pt')) + ] + + + def __len__(self) -> int: + return len(self.paths) + + def process( + self, + img: np.ndarray, + mask: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Process the image/mask and convert to a tensor.""" + img = torch.from_numpy(img) + mask = torch.from_numpy(mask) + return img, mask + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # Load the image and mask. + output = torch.load(self.paths[index]) + img = output['image'] # CHW (np.ndarray) + mask = output['mask'].astype(int) # 1HW (np.ndarray) + + if self.mode == 'multiclass': + mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] + mask = mask.max(axis=0) + elif self.mode == 'binary' and mask.ndim == 3: + mask = np.any(mask, axis=0)[None, :, :].astype(int) + + # Process. + img, mask = self.process(img, mask) + + return { + 'image': img, + 'mask': mask + } + + +class BufferedRandomCropDataset(BufferedMaskDataset): + + def __init__(self, *args, size: int = 1024, **kwargs): + """Dataset that loads buffered image/mask pairs and randomly crops. + + Loaded thumbnails and masks and randomly cropped and rotated together to + a square size of `size` pixels. + + Args: + dataset (sf.Dataset): The dataset to use. + source (str): The directory containing the buffered image/mask pairs. + size (int, optional): The size of the random crop. Defaults to 1024. + + Keyword Args: + mode (str, optional): The mode to use for the mask. One of: + 'binary', 'multiclass', 'multilabel'. Defaults to 'binary'. + + """ + super().__init__(*args, **kwargs) + self.size = size + + def process( + self, + img: np.ndarray, + mask: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Randomly crop/rotate the image and mask and convert to a tensor.""" + return random_crop_and_rotate(img, mask, size=self.size) + +# ----------------------------------------------------------------------------- + +class TileMaskDataset(torch.utils.data.Dataset): + + def __init__( + self, + dataset: "sf.Dataset", + tile_px: int, + tile_um: Union[int, str], + *, + roi_labels: Optional[List[str]] = None, + stride_div: int = 2, + crop_margin: int = 0, + filter_method: str = 'otsu', + mode: str = 'binary' + ): + """Dataset that generates tiles and ROI masks from slides. + + Args: + dataset (sf.Dataset): The dataset to use. + tile_px (int): The size of the tiles (in pixels). + tile_um (Union[int, str]): The size of the tiles (in microns). + + Keyword args: + stride_div (int, optional): The divisor for the stride. + Defaults to 2. + crop_margin (int, optional): The number of pixels to add to the + tile size before random cropping to the target tile_px size. + Defaults to 0 (no random cropping). + filter_method (str, optional): The method to use for identifying + tiles for training. If 'roi', selects only tiles that intersect + with an ROI. If 'otsu', selects tiles based on an Otsu threshold + of the slide. Defaults to 'roi'. + + """ + super().__init__() + + rois = dataset.rois() + slides_with_rois = [path_to_name(r) for r in rois] + slides = [s for s in dataset.slide_paths() + if path_to_name(s) in slides_with_rois] + kw = dict( + tile_px=tile_px + crop_margin, + tile_um=tile_um, + verbose=False, + stride_div=stride_div + ) + if roi_labels is None: + roi_labels = [] + self.mode = mode + self.roi_labels = roi_labels + self.tile_px = tile_px + self.coords = [] + self.all_wsi = dict() + self.all_wsi_with_roi = dict() + self.all_extract_px = dict() + for slide in track(slides, description="Loading slides"): + name = path_to_name(slide) + wsi = sf.WSI(slide, **kw) + try: + wsi_with_rois = sf.WSI(slide, roi_filter_method=0.1, rois=rois, **kw) + except Exception as e: + sf.log.error("Failed to load slide {}: {}".format(slide, e)) + raise e + + # Filter ROIs to only include the specified labels. + if self.roi_labels: + wsi_with_rois.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels] + wsi_with_rois.process_rois() + + if not len(wsi_with_rois.rois): + continue + if filter_method == 'roi': + wsi_inner = sf.WSI(slide, roi_filter_method=0.9, rois=rois, **kw) + if self.roi_labels: + wsi_inner.rois = [roi for roi in wsi_with_rois.rois if roi.label in self.roi_labels] + wsi_inner.process_rois() + coords = np.argwhere(wsi_with_rois.grid & (~wsi_inner.grid)).tolist() + elif filter_method == 'otsu': + wsi.qc('otsu') + coords = np.argwhere(wsi.grid).tolist() + wsi.remove_qc() + elif filter_method in ['all', 'none', None]: + coords = np.argwhere(wsi_with_rois.grid).tolist() + else: + raise ValueError("Invalid filter method: {}. Expected one of: " + "roi, otsu".format(filter_method)) + for c in coords: + self.coords.append([name] + c) + self.all_wsi[name] = wsi + self.all_wsi_with_roi[name] = wsi_with_rois + self.all_extract_px[name] = int(wsi.tile_um / wsi.mpp) + + def __len__(self): + return len(self.coords) + + def get_scaled_and_intersecting_polys( + self, + polys: "Polygon", + tile: "Polygon", + scale: float, + full_stride: int, + grid_idx: Tuple[int, int] + ): + """Get scaled and intersecting polygons for a given tile.""" + gx, gy = grid_idx + A = polys.intersection(tile) + + # Translate polygons so the intersection origin is at (0, 0) + B = sa.translate(A, -(full_stride*gx), -(full_stride*gy)) + + # Scale to the target tile size + C = sa.scale(B, xfact=scale, yfact=scale, origin=(0, 0)) + return C + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Get an image and mask for a given index.""" + slide, gx, gy = self.coords[index] + wsi = self.all_wsi[slide] + wsi_with_roi = self.all_wsi_with_roi[slide] + fe = self.all_extract_px[slide] + fs = wsi.full_stride + scale = wsi.tile_px / fe + + # Get the image. + img = wsi[gx, gy].transpose(2, 0, 1) + + # Get a polygon for the tile, used for determining overlapping ROIs. + tile = Polygon([ + [fs*gx, fs*gy], + [fs*gx, (fs*gy)+fe], + [(fs*gx)+fe, (fs*gy)+fe], + [(fs*gx)+fe, fs*gy] + ]) + + # Compute the mask from ROIs. + if len(wsi_with_roi.rois) == 0: + if self.roi_labels: + mask = np.zeros((len(self.roi_labels), wsi.tile_px, wsi.tile_px), dtype=int) + else: + mask = np.zeros((1, wsi.tile_px, wsi.tile_px), dtype=int) + + # Handle ROIs with labels (multilabel or multiclass) + elif self.roi_labels: + labeled_masks = [] + for i, label in enumerate(self.roi_labels): + wsi_polys = [p.poly for p in wsi_with_roi.rois if p.label == label] + if len(wsi_polys) == 0: + mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) + labeled_masks.append(mask) + else: + all_polys = unary_union(wsi_polys) + polys = self.get_scaled_and_intersecting_polys( + all_polys, tile, scale, fs, (gx, gy) + ) + if isinstance(polys, Polygon) and polys.is_empty: + mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) + else: + # Rasterize to an int mask. + mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(int) + labeled_masks.append(mask) + mask = np.stack(labeled_masks, axis=0) + + # Handle ROIs without labels (binary) + else: + # Determine the intersection at the given tile location. + all_polys = unary_union([p.poly for p in wsi_with_roi.rois]) + polys = self.get_scaled_and_intersecting_polys( + all_polys, tile, scale, fs, (gx, gy) + ) + + if isinstance(polys, Polygon) and polys.is_empty: + mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) + else: + # Rasterize to an int mask. + try: + mask = rasterio.features.rasterize([polys], out_shape=[wsi.tile_px, wsi.tile_px]).astype(bool).astype(np.int32) + except ValueError: + mask = np.zeros((wsi.tile_px, wsi.tile_px), dtype=int) + + # Add a dummy channel dimension. + mask = mask[None, :, :] + + # Process according to the mode. + if self.mode == 'multiclass': + mask = mask * np.arange(1, mask.shape[0]+1)[:, None, None] + mask = mask.max(axis=0) + elif self.mode == 'binary' and mask.ndim == 3: + mask = np.any(mask, axis=0)[None, :, :].astype(int) + + # Process. + img, mask = self.process(img, mask) + + return { + 'image': img, + 'mask': mask + } + + def process( + self, + img: np.ndarray, + mask: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Randomly crop/rotate the image and mask and convert to a tensor.""" + return random_crop_and_rotate(img, mask, size=self.tile_px) + +# ----------------------------------------------------------------------------- + +def random_crop_and_rotate(img, mask, size): + if mask.ndim == 2: + to_squeeze = True + mask = mask[None, :, :] + else: + to_squeeze = False + + # Convert to tensor. + img = torch.from_numpy(img).permute(1, 2, 0) + mask = torch.from_numpy(mask).permute(1, 2, 0) + + # Pad to target size. + img = topleft_pad_torch(img, size).permute(2, 0, 1) + mask = topleft_pad_torch(mask, size).permute(2, 0, 1) + + # Random crop. + i, j, h, w = transforms.RandomCrop.get_params( + img, output_size=(size, size)) + img = transforms.functional.crop(img, i, j, h, w) + mask = transforms.functional.crop(mask, i, j, h, w) + + # Random flip. + if np.random.rand() > 0.5: + img = transforms.functional.hflip(img) + mask = transforms.functional.hflip(mask) + if np.random.rand() > 0.5: + img = transforms.functional.vflip(img) + mask = transforms.functional.vflip(mask) + + # Random cardinal rotation. + r = np.random.randint(4) + img = transforms.functional.rotate(img, r * 90) + mask = transforms.functional.rotate(mask, r * 90) + + if to_squeeze: + mask = mask.squeeze(0) + + return img, mask + +# ----------------------------------------------------------------------------- + +def get_thumb_and_mask( + wsi: "sf.WSI", + mpp: float, + roi_labels: Optional[List[str]] = None, + skip_missing: bool = False +) -> Dict[str, np.ndarray]: + """Get a thumbnail and segmentation mask for a slide.""" + + if len(wsi.rois) == 0 and skip_missing: + return None + + # Sanity check. + width = int((wsi.mpp * wsi.dimensions[0]) / mpp) + ds = wsi.dimensions[0] / width + level = wsi.slide.best_level_for_downsample(ds) + level_dim = wsi.slide.level_dimensions[level] + if any([d > 10000 for d in level_dim]): + sf.log.warning("Large thumbnail found ({}) at level={} for {}".format( + level_dim, level, wsi.path + )) + + # Get the thumbnail. + thumb = wsi.thumb(mpp=mpp).convert('RGB') + img = np.array(thumb).transpose(2, 0, 1) + xfact = thumb.size[1] / wsi.dimensions[1] + yfact = thumb.size[0] / wsi.dimensions[0] + + if len(wsi.rois) == 0: + if roi_labels: + mask = np.zeros((len(roi_labels), thumb.size[1], thumb.size[0])).astype(bool) + else: + mask = np.zeros((1, thumb.size[1], thumb.size[0])).astype(bool) + elif roi_labels: + labeled_masks = [] + for i, label in enumerate(roi_labels): + wsi_polys = [p.poly for p in wsi.rois if p.label == label] + if len(wsi_polys) == 0: + mask = np.zeros((thumb.size[1], thumb.size[0])).astype(bool) + labeled_masks.append(mask) + else: + all_polys = unary_union(wsi_polys) + # Scale ROIs to the thumbnail size. + C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0)) + # Rasterize to an int mask. + mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool).astype(np.int32) + labeled_masks.append(mask) + mask = np.stack(labeled_masks, axis=0) + + else: + all_polys = unary_union([p.poly for p in wsi.rois]) + # Scale ROIs to the thumbnail size. + C = sa.scale(all_polys, xfact=xfact, yfact=yfact, origin=(0, 0)) + # Rasterize to an int mask. + mask = rasterio.features.rasterize([C], out_shape=(thumb.size[1], thumb.size[0])).astype(bool) + # Add a dummy channel dimension. + mask = mask[None, :, :] + + assert img.shape[1:] == mask.shape[1:], "Image and mask must have the same dimensions." + assert mask.ndim == 3, "Mask must have 3 dimensions (C, H, W)." + assert img.ndim == 3, "Image must have 3 dimensions (C, H, W)." + + return { + 'image': img, + 'mask': mask + }