Switch to side-by-side view

--- a
+++ b/slideflow/slide/backends/cucim.py
@@ -0,0 +1,577 @@
+"""cuCIM slide-reading backend.
+
+Requires: cuCIM (...)
+"""
+
+import cv2
+import numpy as np
+
+from types import SimpleNamespace
+from typing import Optional, Dict, Any, Tuple, List, TYPE_CHECKING
+from slideflow.util import log
+from skimage.transform import resize
+from skimage.util import img_as_float32
+from skimage.color import rgb2hsv
+from slideflow.slide.utils import *
+
+if TYPE_CHECKING:
+    from cucim import CuImage
+    import cupy as cp
+
+# -----------------------------------------------------------------------------
+
+SUPPORTED_BACKEND_FORMATS = ['svs', 'tif', 'tiff']
+
+# -----------------------------------------------------------------------------
+
+__cv2_resize__ = True
+__cuimage__ = None
+__cuimage_path__ = None
+
+# -----------------------------------------------------------------------------
+
+def get_cucim_reader(path: str, *args, **kwargs):
+    return _cuCIMReader(path, *args, **kwargs)
+
+
+def cucim2numpy(img: Union["CuImage", "cp.ndarray", "np.ndarray"]) -> np.ndarray:
+    """Convert a cuCIM image to a numpy array."""
+    from cucim import CuImage
+    if isinstance(img, CuImage):
+        np_img = np.asarray(img)
+    elif isinstance(img, np.ndarray):
+        np_img = img
+    else:
+        import cupy as cp
+        if isinstance(img, cp.ndarray):
+            np_img = img.get()
+        else:
+            raise ValueError(f"Unsupported image type: {type(img)}")
+    return ((img_as_float32(np_img)) * 255).astype(np.uint8)
+
+
+def cucim2jpg(img: "CuImage") -> str:
+    img = cucim2numpy(img)
+    return numpy2jpg(img)
+
+
+def cucim2png(img: "CuImage") -> str:
+    img = cucim2numpy(img)
+    return numpy2png(img)
+
+
+def cucim_padded_crop(
+    img: "CuImage",
+    location: Tuple[int, int],
+    size: Tuple[int, int],
+    level: int,
+    **kwargs
+) -> Union["CuImage", "np.ndarray"]:
+    """Read a region from the image, padding missing data.
+
+    Args:
+        img (CuImage): Image to read from.
+        location (Tuple[int, int]): Top-left location of the region to extract,
+            using base layer coordinates (x, y).
+        size (Tuple[int, int]): Size of the region to read (width, height).
+        level (int): Pyramid level to read from.
+        **kwargs: Additional arguments for reading the region.
+
+    Returns:
+        Original image (``CuImage``) if the region is within bounds, otherwise
+        a padded region (``np.ndarray``).
+
+    """
+    x, y = location
+    width, height = size
+    slide_height, slide_width = img.shape[0], img.shape[1]
+    bg = [255]
+    # Note that for cucim images, the shape is (height, width, channels).
+    # First, return the original image if the region is within bounds.
+    if (x >= 0 and y >= 0 and x + width <= slide_width and y + height <= slide_height):
+        return img.read_region(location=(x, y), size=(width, height), level=level, **kwargs)
+    # Otherwise, pad the missing region with white.
+    # First, find the region that is within bounds.
+    x1, y1 = max(0, x), max(0, y)
+    x2, y2 = min(slide_width, x + width), min(slide_height, y + height)
+    # Read the region within bounds.
+    region = img.read_region(location=(x1, y1), size=(x2 - x1, y2 - y1), level=level, **kwargs)
+    # Convert to a numpy array.
+    region_cp = np.asarray(region)
+    # Use np.pad to pad the region.
+    pad_width = ((max(0, -y), max(0, y + height - slide_height)),
+                 (max(0, -x), max(0, x + width - slide_width)),
+                 (0, 0))
+    region_cp = np.pad(region_cp, pad_width, mode='constant', constant_values=bg)
+    return region_cp
+
+
+def tile_worker(
+    c: List[int],
+    args: SimpleNamespace
+) -> Optional[Union[str, Dict]]:
+    """Multiprocessing worker for WSI. Extracts tile at given coordinates."""
+
+    if args.has_segmentation:
+        c, tile_mask = c
+        (x, y, grid_x), grid_y = c, 0
+    else:
+        tile_mask = None
+        x, y, grid_x, grid_y = c
+
+    x_coord = int(x + args.full_extract_px / 2)
+    y_coord = int(y + args.full_extract_px / 2)
+
+    # If downsampling is enabled, read image from highest level
+    # to perform filtering; otherwise filter from our target level
+    slide = get_cucim_reader(args.path, args.mpp_override, **args.reader_kwargs)
+    if args.whitespace_fraction < 1 or args.grayspace_fraction < 1:
+        if args.filter_downsample_ratio > 1:
+            filter_extract_px = args.extract_px // args.filter_downsample_ratio
+            filter_region = slide.read_region(
+                (x, y),
+                args.filter_downsample_level,
+                (filter_extract_px, filter_extract_px)
+            )
+        else:
+            # Read the region and resize to target size
+            filter_region = slide.read_region(
+                (x, y),
+                args.downsample_level,
+                (args.extract_px, args.extract_px)
+            )
+        try:
+            # Perform whitespace filtering [cucim]
+            if args.whitespace_fraction < 1:
+                ws_fraction = np.mean((np.mean(cucim2numpy(filter_region), axis=-1) > args.whitespace_threshold))
+                if (ws_fraction > args.whitespace_fraction
+                and args.whitespace_fraction != FORCE_CALCULATE_WHITESPACE):
+                    return None
+
+            # Perform grayspace filtering [cucim]
+            if args.grayspace_fraction < 1:
+                hsv_region = rgb2hsv(np.asarray(filter_region))
+                gs_fraction = np.mean(hsv_region[:, :, 1] < args.grayspace_threshold)
+                if (gs_fraction > args.grayspace_fraction
+                and args.whitespace_fraction != FORCE_CALCULATE_WHITESPACE):
+                    return None
+        except IndexError:
+            return None
+
+    # Prepare return dict with WS/GS fraction
+    return_dict = {'loc': [x_coord, y_coord]}  # type: Dict[str, Any]
+    return_dict.update({'grid': [grid_x, grid_y]})
+    if args.grayspace_fraction < 1:
+        return_dict.update({'gs_fraction': gs_fraction})
+    if args.whitespace_fraction < 1:
+        return_dict.update({'ws_fraction': ws_fraction})
+
+    # If dry run, return without the image
+    if args.dry_run:
+        return_dict.update({'loc': [x_coord, y_coord]})
+        return return_dict
+
+    # If using a segmentation mask, resize mask to match the tile size.
+    if tile_mask is not None:
+        tile_mask = cv2.resize(
+            tile_mask,
+            (args.tile_px, args.tile_px),
+            interpolation=cv2.INTER_NEAREST)
+
+    # Read the target downsample region now, if we were
+    # filtering at a different level
+    region = slide.read_region(
+        (x, y),
+        args.downsample_level,
+        (args.extract_px, args.extract_px)
+    )
+    # If the region is None (out of bounds), return None
+    if region is None:
+        return None
+
+    # cuCIM resize
+    if not __cv2_resize__:
+        if int(args.tile_px) != int(args.extract_px):
+            region = resize(np.asarray(region), (args.tile_px, args.tile_px))
+
+    region = cucim2numpy(region)
+
+    # cv2 resize
+    if __cv2_resize__:
+        if int(args.tile_px) != int(args.extract_px):
+            region = cv2.resize(region, (args.tile_px, args.tile_px))
+
+    assert(region.shape[0] == region.shape[1] == args.tile_px)
+
+    # Remove the alpha channel and convert to RGB
+    if region.shape[-1] == 4:
+        region = region[:, :, 0:3]
+
+    # Apply segmentation mask
+    if tile_mask is not None:
+        region[tile_mask == 0] = (0, 0, 0)
+
+    # Apply normalization
+    if args.normalizer:
+        try:
+            region = args.normalizer.rgb_to_rgb(region)
+        except Exception:
+            # The image could not be normalized,
+            # which happens when a tile is primarily one solid color
+            return None
+
+    if args.img_format != 'numpy':
+        image = cv2.cvtColor(region, cv2.COLOR_RGB2BGR)
+        # Default image quality for JPEG is 95%
+        image = cv2.imencode("."+args.img_format, image)[1].tobytes()
+    else:
+        image = region
+
+    # Include ROI / bounding box processing.
+    # Used to visualize ROIs on extracted tiles, or to generate YoloV5 labels.
+    if args.yolo or args.draw_roi:
+        coords, boxes, yolo_anns = roi_coords_from_image(c, args)
+    if args.draw_roi:
+        image = draw_roi(image, coords)
+
+    return_dict.update({'image': image})
+    if args.yolo:
+        return_dict.update({'yolo': yolo_anns})
+    return return_dict
+
+
+class _cuCIMReader:
+
+    has_levels = True
+
+    def __init__(
+        self,
+        path: str,
+        mpp: Optional[float] = None,
+        *,
+        cache_kw: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        ignore_missing_mpp: bool = True,
+        pad_missing: bool = True,
+        use_bounds: bool = False,  #TODO: Not yet implemented
+    ):
+        '''Wrapper for cuCIM reader to preserve cross-compatible functionality.'''
+        global __cuimage__, __cuimage_path__
+
+        from cucim import CuImage
+
+        self.path = path
+        self.pad_missing = pad_missing
+        self.cache_kw = cache_kw if cache_kw else {}
+        self.loaded_downsample_levels = {}  # type: Dict[int, "CuImage"]
+        if path == __cuimage_path__:
+            self.reader = __cuimage__
+        else:
+            __cuimage__ = self.reader = CuImage(path)
+            __cuimage_path__ = path
+        self.num_workers = num_workers
+        self._mpp = None
+
+        # Check for Microns-per-pixel (MPP)
+        if mpp is not None:
+            log.debug(f"Manually setting MPP to {mpp}")
+            self._mpp = mpp
+        for prop_key in self.metadata:
+            if self._mpp is not None:
+                break
+            if 'MPP' in self.metadata[prop_key]:
+                self._mpp = self.metadata[prop_key]['MPP']
+                #log.debug(f'Setting MPP by metadata ({prop_key}) "MPP" to {self._mpp}')
+            elif 'DICOM_PIXEL_SPACING' in self.metadata[prop_key]:
+                ps = self.metadata[prop_key]['DICOM_PIXEL_SPACING'][0]
+                self._mpp = ps * 1000  # Convert from millimeters -> microns
+                #log.debug(f'Setting MPP by metadata ({prop_key}) "DICOM_PIXEL_SPACING" to {self._mpp}')
+            elif 'spacing' in self.metadata[prop_key]:
+                ps = self.metadata[prop_key]['spacing']
+                if isinstance(ps, (list, tuple)):
+                    ps = ps[0]
+                if 'spacing_units' in self.metadata[prop_key]:
+                    spacing_unit = self.metadata[prop_key]['spacing_units']
+                    if isinstance(spacing_unit, (list, tuple)):
+                        spacing_unit = spacing_unit[0]
+                    if spacing_unit in ('mm', 'millimeters', 'millimeter'):
+                        self._mpp = ps * 1000
+                    elif spacing_unit in ('cm', 'centimeters', 'centimeter'):
+                        self._mpp = ps * 10000
+                    elif spacing_unit in ('um', 'microns', 'micrometers', 'micrometer'):
+                        self._mpp = ps
+                    else:
+                        continue
+                    #log.debug(f'Setting MPP by metadata ({prop_key}) "spacing" ({spacing_unit}) to {self._mpp}')
+        if not self.mpp:
+            log.warn("Unable to auto-detect microns-per-pixel (MPP).")
+
+        # Pyramid layers
+        self.dimensions = tuple(self.properties['shape'][0:2][::-1])
+        self.levels = []
+        for lev in range(self.level_count):
+            self.levels.append({
+                'dimensions': self.level_dimensions[lev],
+                'width': self.level_dimensions[lev][0],
+                'height': self.level_dimensions[lev][1],
+                'downsample': self.level_downsamples[lev],
+                'level': lev
+            })
+
+    @property
+    def mpp(self):
+        return self._mpp
+
+    def has_mpp(self):
+        return self._mpp is not None
+
+    @property
+    def metadata(self):
+        return self.reader.metadata
+
+    @property
+    def properties(self):
+        return self.reader.metadata['cucim']
+
+    @property
+    def resolutions(self):
+        return self.properties['resolutions']
+
+    @property
+    def level_count(self):
+        return self.resolutions['level_count']
+
+    @property
+    def level_dimensions(self):
+        return self.resolutions['level_dimensions']
+
+    @property
+    def level_downsamples(self):
+        return self.resolutions['level_downsamples']
+
+    @property
+    def level_tile_sizes(self):
+        return self.resolutions['level_tile_sizes']
+
+    def best_level_for_downsample(
+        self,
+        downsample: float,
+    ) -> int:
+        '''Return lowest magnification level with a downsample level lower than
+        the given target.
+
+        Args:
+            downsample (float): Ratio of target resolution to resolution
+                at the highest magnification level. The downsample level of the
+                highest magnification layer is equal to 1.
+            levels (list(int), optional): Valid levels to search. Defaults to
+                None (search all levels).
+
+        Returns:
+            int:    Optimal downsample level.
+        '''
+        max_downsample = 0
+        for d in self.level_downsamples:
+            if d < downsample:
+                max_downsample = d
+        try:
+            max_level = self.level_downsamples.index(max_downsample)
+        except Exception:
+            log.debug(f"Error attempting to read level {max_downsample}")
+            return 0
+        return max_level
+
+    def coord_to_raw(self, x, y):
+        return x, y
+
+    def raw_to_coord(self, x, y):
+        return x, y
+
+    def read_level(self, level: int, to_numpy: bool = False):
+        """Read a pyramid level."""
+        image = self.reader.read_region(level=level)
+        if to_numpy:
+            return cucim2numpy(image)
+        else:
+            return image
+
+    def read_region(
+        self,
+        base_level_dim: Tuple[int, int],
+        downsample_level: int,
+        extract_size: Tuple[int, int],
+        *,
+        convert: Optional[str] = None,
+        flatten: bool = False,
+        resize_factor: Optional[float] = None,
+        pad_missing: Optional[bool] = None
+    ) -> Optional[Union["CuImage", np.ndarray, str]]:
+        """Extracts a region from the image at the given downsample level.
+
+        Args:
+            base_level_dim (Tuple[int, int]): Top-left location of the region
+                to extract, using base layer coordinates (x, y)
+            downsample_level (int): Downsample level to read.
+            extract_size (Tuple[int, int]): Size of the region to read
+                (width, height) using downsample layer coordinates.
+
+        Keyword args:
+            pad_missing (bool, optional): Pad missing regions with black.
+                If None, uses the value of the `pad_missing` attribute.
+                Defaults to None.
+            convert (str, optional): Convert the image to a different format.
+                Supported formats are 'jpg', 'jpeg', 'png', and 'numpy'.
+                Defaults to None.
+            flatten (bool, optional): Flatten the image to 3 channels.
+                Defaults to False.
+            resize_factor (float, optional): Resize the image by this factor.
+                Defaults to None.
+
+
+        Returns:
+            Image in the specified format.
+
+        """
+        # Define region kwargs
+        region_kwargs = dict(
+            location=base_level_dim,
+            size=(int(extract_size[0]), int(extract_size[1])),
+            level=downsample_level,
+            num_workers=self.num_workers,
+        )
+        # Pad missing data, if enabled
+        if ((pad_missing is not None and pad_missing)
+        or (pad_missing is None and self.pad_missing)):
+            try:
+                region = cucim_padded_crop(self.reader, **region_kwargs)
+            except ValueError as e:
+                log.warning(f"Error reading region via padded crop with kwargs=({region_kwargs}): {e}")
+                return None
+        else:
+            # If padding is disabled, this will raise a ValueError.
+            try:
+                region = self.reader.read_region(**region_kwargs)
+            except ValueError as e:
+                log.warning(f"Error reading region with kwargs=({region_kwargs}): {e}")
+                return None
+
+        # Resize using the same interpolation strategy as the Libvips backend (cv2).
+        if resize_factor:
+            target_size = (int(np.round(extract_size[0] * resize_factor)),
+                           int(np.round(extract_size[1] * resize_factor)))
+            if not __cv2_resize__:
+                region = resize(cucim2numpy(region), target_size)
+
+        # Final conversions.
+        if flatten and region.shape[-1] == 4:
+            region = region[:, :, 0:3]
+        if (convert
+            and convert.lower() in ('jpg', 'jpeg', 'png', 'numpy')
+            and not isinstance(region, np.ndarray)):
+            region = cucim2numpy(region)
+        if resize_factor and __cv2_resize__:
+            region = cv2.resize(region, target_size)
+        if convert and convert.lower() in ('jpg', 'jpeg'):
+            return numpy2jpg(region)
+        elif convert and convert.lower() == 'png':
+            return numpy2png(region)
+        return region
+
+    def read_from_pyramid(
+        self,
+        top_left: Tuple[int, int],
+        window_size: Tuple[int, int],
+        target_size: Tuple[int, int],
+        *,
+        convert: Optional[str] = None,
+        flatten: bool = False,
+        pad_missing: Optional[bool] = None
+    ) -> "CuImage":
+        """Reads a region from the image using base layer coordinates.
+        Performance is accelerated by pyramid downsample layers, if available.
+
+        Args:
+            top_left (Tuple[int, int]): Top-left location of the region to
+                extract, using base layer coordinates (x, y).
+            window_size (Tuple[int, int]): Size of the region to read (width,
+                height) using base layer coordinates.
+            target_size (Tuple[int, int]): Resize the region to this target
+                size (width, height).
+
+        Keyword args:
+            convert (str, optional): Convert the image to a different format.
+                Supported formats are 'jpg', 'jpeg', 'png', and 'numpy'.
+                Defaults to None.
+            flatten (bool, optional): Flatten the image to 3 channels.
+                Defaults to False.
+            pad_missing (bool, optional): Pad missing regions with black.
+                If None, uses the value of the `pad_missing` attribute.
+                Defaults to None.
+
+        Returns:
+            CuImage: Image. Dimensions will equal target_size unless
+            the window includes an area of the image which is out of bounds.
+            In this case, the returned image will be cropped.
+        """
+        target_downsample = window_size[0] / target_size[0]
+        ds_level = self.best_level_for_downsample(target_downsample)
+
+        # Use a lower downsample level if the window size is too small
+        ds = self.level_downsamples[ds_level]
+        if not int(window_size[0] / ds) or not int(window_size[1] / ds):
+            ds_level = max(0, ds_level-1)
+            ds = self.level_downsamples[ds_level]
+
+        # Define region kwargs
+        region_kwargs = dict(
+            location=top_left,
+            size=(int(window_size[0] / ds), int(window_size[1] / ds)),
+            level=ds_level,
+            num_workers=self.num_workers,
+        )
+        if ((pad_missing is not None and pad_missing)
+              or (pad_missing is None and self.pad_missing)):
+            region = cucim_padded_crop(self.reader, **region_kwargs)
+        else:
+            region = self.read_region(**region_kwargs)
+
+        # Resize using the same interpolation strategy as the Libvips backend (cv2).
+        if not __cv2_resize__:
+            region = resize(cucim2numpy(region), (target_size[1], target_size[0]))
+
+        # Final conversions
+        if flatten and region.shape[-1] == 4:
+            region = region[:, :, 0:3]
+        if (convert
+            and convert.lower() in ('jpg', 'jpeg', 'png', 'numpy')
+            and not isinstance(region, np.ndarray)):
+            region = cucim2numpy(region)
+        if __cv2_resize__:
+            region = cv2.resize(region, target_size)
+        if convert and convert.lower() in ('jpg', 'jpeg'):
+            return numpy2jpg(region)
+        elif convert and convert.lower() == 'png':
+            return numpy2png(region)
+        return region
+
+    def thumbnail(
+        self,
+        width: int = 512,
+        level: Optional[int] = None,
+        associated: bool = False
+    ) -> np.ndarray:
+        """Return thumbnail of slide as numpy array."""
+        if associated:
+            log.debug("associated=True not implemented for cucim() thumbnail,"
+                      "reading from lowest-magnification layer.")
+        if level is None:
+            level = self.level_count - 1
+        w, h = self.dimensions
+        height = int((width * h) / w)
+        img = self.read_level(level=level)
+        if __cv2_resize__:
+            img = cucim2numpy(img)
+            return cv2.resize(img, (width, height))
+        else:
+            img = resize(np.asarray(img), (width, height))
+            return cucim2numpy(img)