# slideflow/slide/backends/cucim.py
1
"""cuCIM slide-reading backend.
2
3
Requires: cuCIM (...)
4
"""
5
6
import cv2
7
import numpy as np
8
9
from types import SimpleNamespace
10
from typing import Optional, Dict, Any, Tuple, List, TYPE_CHECKING
11
from slideflow.util import log
12
from skimage.transform import resize
13
from skimage.util import img_as_float32
14
from skimage.color import rgb2hsv
15
from slideflow.slide.utils import *
16
17
if TYPE_CHECKING:
18
    from cucim import CuImage
19
    import cupy as cp
20
21
# -----------------------------------------------------------------------------
22
23
# Slide file formats declared as supported by this backend.
SUPPORTED_BACKEND_FORMATS = ['svs', 'tif', 'tiff']
24
25
# -----------------------------------------------------------------------------
26
27
# Resize switch: when True the functions in this module resize with
# ``cv2.resize``; when False they use ``skimage.transform.resize``.
__cv2_resize__ = True
# One-entry cache of the most recently opened CuImage and the path it was
# opened from; ``_cuCIMReader.__init__`` reuses the handle when the same path
# is requested again.
__cuimage__ = __cuimage_path__ = None
30
31
# -----------------------------------------------------------------------------
32
33
def get_cucim_reader(path: str, *args, **kwargs):
    """Build a cuCIM-backed slide reader for ``path``.

    Every additional positional and keyword argument is forwarded unchanged
    to the ``_cuCIMReader`` constructor.
    """
    reader = _cuCIMReader(path, *args, **kwargs)
    return reader
35
36
37
def cucim2numpy(img: Union["CuImage", "cp.ndarray", "np.ndarray"]) -> np.ndarray:
    """Convert a cuCIM image (or a numpy / cupy array) to a uint8 numpy array.

    Args:
        img: A ``CuImage``, ``numpy.ndarray`` or ``cupy.ndarray``.

    Returns:
        np.ndarray: A newly allocated array with dtype ``uint8``. Data that is
        already ``uint8`` is copied unchanged; any other dtype is converted
        with ``img_as_float32`` and multiplied by 255 before the cast.

    Raises:
        ValueError: If ``img`` is none of the supported types.
    """
    if isinstance(img, np.ndarray):
        # Checked first so that plain numpy input does not need cuCIM or cupy
        # to be importable.
        np_img = img
    else:
        from cucim import CuImage
        if isinstance(img, CuImage):
            np_img = np.asarray(img)
        else:
            import cupy as cp
            if isinstance(img, cp.ndarray):
                np_img = img.get()
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

    if np_img.dtype == np.uint8:
        # Already 8-bit: skip the float32 round trip. The cast back to uint8
        # truncates, so a product landing just below an integer would lose one
        # intensity level. A copy is returned so callers can still mutate the
        # result without touching the source buffer.
        return np_img.astype(np.uint8, copy=True)
    return ((img_as_float32(np_img)) * 255).astype(np.uint8)
51
52
53
def cucim2jpg(img: "CuImage") -> str:
    """Convert ``img`` to a uint8 numpy array and hand it to ``numpy2jpg``."""
    return numpy2jpg(cucim2numpy(img))
56
57
58
def cucim2png(img: "CuImage") -> str:
    """Convert ``img`` to a uint8 numpy array and hand it to ``numpy2png``."""
    return numpy2png(cucim2numpy(img))
61
62
63
def cucim_padded_crop(
    img: "CuImage",
    location: Tuple[int, int],
    size: Tuple[int, int],
    level: int,
    **kwargs
) -> Union["CuImage", "np.ndarray"]:
    """Read a region from the image, padding missing data with white (255).

    ``location`` is expressed in base-layer pixels whereas ``size`` is
    expressed in pixels of the requested pyramid ``level``. The requested
    extent is therefore scaled by the level's downsample factor before it is
    compared against the base-layer image shape, and all padding amounts are
    computed in level pixels so the result always has the requested ``size``.

    Args:
        img (CuImage): Image to read from.
        location (Tuple[int, int]): Top-left location of the region to extract,
            using base layer coordinates (x, y).
        size (Tuple[int, int]): Size of the region to read (width, height),
            in pixels of ``level``.
        level (int): Pyramid level to read from.
        **kwargs: Additional arguments for reading the region.

    Returns:
        The object returned by ``img.read_region`` if the region is within
        bounds, otherwise a padded region (``np.ndarray``).

    Raises:
        ValueError: If the requested region does not overlap the image at all.

    """
    import math

    x, y = location
    width, height = size
    # For cucim images, the shape is (height, width, channels).
    slide_height, slide_width = img.shape[0], img.shape[1]
    bg = [255]

    # Downsample factor of the requested level. Level 0 is the base layer; if
    # the image object does not expose per-level downsamples, fall back to 1,
    # which treats ``size`` as base-layer pixels.
    downsample = 1.0
    resolutions = getattr(img, 'resolutions', None)
    if level and resolutions:
        try:
            downsample = float(resolutions['level_downsamples'][level])
        except (KeyError, IndexError, TypeError, ValueError):
            downsample = 1.0
    if downsample <= 0:
        downsample = 1.0

    def _to_level_px(base_px):
        # Convert a base-layer overhang into whole level pixels. Rounding up
        # keeps the region that is actually read inside the image; the small
        # tolerance stops float noise from adding a spurious 1-pixel pad.
        return max(0, int(math.ceil(max(0, base_px) / downsample - 1e-6)))

    pad_left = _to_level_px(-x)
    pad_top = _to_level_px(-y)
    pad_right = _to_level_px(x + width * downsample - slide_width)
    pad_bottom = _to_level_px(y + height * downsample - slide_height)

    # Fully in bounds: return the region exactly as cuCIM provides it.
    if not (pad_left or pad_top or pad_right or pad_bottom):
        return img.read_region(
            location=(x, y), size=(width, height), level=level, **kwargs
        )

    inner_width = width - pad_left - pad_right
    inner_height = height - pad_top - pad_bottom
    if inner_width <= 0 or inner_height <= 0:
        raise ValueError(
            f"Region at location={location} with size={size} (level={level}) "
            "does not overlap the image."
        )

    # Read only the part of the region that lies within the image.
    region = img.read_region(
        location=(max(0, x), max(0, y)),
        size=(inner_width, inner_height),
        level=level,
        **kwargs
    )
    region_np = np.asarray(region)
    # Pad rows (top, bottom) and columns (left, right); channels are untouched.
    pad_width = ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0))
    return np.pad(region_np, pad_width, mode='constant', constant_values=bg)
107
108
109
def tile_worker(
    c: List[int],
    args: SimpleNamespace
) -> Optional[Union[str, Dict]]:
    """Multiprocessing worker for WSI. Extracts tile at given coordinates.

    Args:
        c: Tile coordinates. When ``args.has_segmentation`` is false this is
            ``[x, y, grid_x, grid_y]``. Otherwise it is a pair
            ``([x, y, grid_x], tile_mask)`` and ``grid_y`` is set to 0.
        args: Namespace of extraction settings. This function reads
            ``has_segmentation``, ``full_extract_px``, ``path``,
            ``mpp_override``, ``reader_kwargs``, ``whitespace_fraction``,
            ``whitespace_threshold``, ``grayspace_fraction``,
            ``grayspace_threshold``, ``filter_downsample_ratio``,
            ``filter_downsample_level``, ``downsample_level``, ``extract_px``,
            ``tile_px``, ``dry_run``, ``normalizer``, ``img_format``, ``yolo``
            and ``draw_roi``.

    Returns:
        ``None`` if the tile is rejected (filtered out, unreadable, or not
        normalizable). Otherwise a dict with ``'loc'`` and ``'grid'``, plus
        ``'ws_fraction'`` / ``'gs_fraction'`` when the matching filter is
        enabled, ``'image'`` unless ``args.dry_run`` is set, and ``'yolo'``
        when ``args.yolo`` is set.
    """

    if args.has_segmentation:
        c, tile_mask = c
        (x, y, grid_x), grid_y = c, 0
    else:
        tile_mask = None
        x, y, grid_x, grid_y = c

    x_coord = int(x + args.full_extract_px / 2)
    y_coord = int(y + args.full_extract_px / 2)

    # If downsampling is enabled, read image from highest level
    # to perform filtering; otherwise filter from our target level
    slide = get_cucim_reader(args.path, args.mpp_override, **args.reader_kwargs)
    if args.whitespace_fraction < 1 or args.grayspace_fraction < 1:
        if args.filter_downsample_ratio > 1:
            filter_extract_px = args.extract_px // args.filter_downsample_ratio
            filter_region = slide.read_region(
                (x, y),
                args.filter_downsample_level,
                (filter_extract_px, filter_extract_px)
            )
        else:
            # Read the region and resize to target size
            filter_region = slide.read_region(
                (x, y),
                args.downsample_level,
                (args.extract_px, args.extract_px)
            )
        # The reader returns None when the underlying read fails. Skip the
        # tile here, as is done for the main read below, rather than letting
        # the conversion raise on None.
        if filter_region is None:
            return None
        try:
            # Perform whitespace filtering [cucim]
            if args.whitespace_fraction < 1:
                ws_fraction = np.mean((np.mean(cucim2numpy(filter_region), axis=-1) > args.whitespace_threshold))
                if (ws_fraction > args.whitespace_fraction
                        and args.whitespace_fraction != FORCE_CALCULATE_WHITESPACE):
                    return None

            # Perform grayspace filtering [cucim]
            if args.grayspace_fraction < 1:
                hsv_region = rgb2hsv(np.asarray(filter_region))
                gs_fraction = np.mean(hsv_region[:, :, 1] < args.grayspace_threshold)
                # NOTE(review): this rejection is gated on whitespace_fraction,
                # exactly like the whitespace test above -- confirm whether a
                # grayspace-specific sentinel was intended.
                if (gs_fraction > args.grayspace_fraction
                        and args.whitespace_fraction != FORCE_CALCULATE_WHITESPACE):
                    return None
        except IndexError:
            return None

    # Prepare return dict with WS/GS fraction
    return_dict = {'loc': [x_coord, y_coord]}  # type: Dict[str, Any]
    return_dict.update({'grid': [grid_x, grid_y]})
    if args.grayspace_fraction < 1:
        return_dict.update({'gs_fraction': gs_fraction})
    if args.whitespace_fraction < 1:
        return_dict.update({'ws_fraction': ws_fraction})

    # If dry run, return without the image ('loc' is already set above).
    if args.dry_run:
        return return_dict

    # If using a segmentation mask, resize mask to match the tile size.
    if tile_mask is not None:
        tile_mask = cv2.resize(
            tile_mask,
            (args.tile_px, args.tile_px),
            interpolation=cv2.INTER_NEAREST)

    # Read the target downsample region now, if we were
    # filtering at a different level
    region = slide.read_region(
        (x, y),
        args.downsample_level,
        (args.extract_px, args.extract_px)
    )
    # If the region is None (out of bounds), return None
    if region is None:
        return None

    needs_resize = int(args.tile_px) != int(args.extract_px)

    # skimage resize (applied before the uint8 conversion)
    if not __cv2_resize__ and needs_resize:
        region = resize(np.asarray(region), (args.tile_px, args.tile_px))

    region = cucim2numpy(region)

    # cv2 resize (applied after the uint8 conversion)
    if __cv2_resize__ and needs_resize:
        region = cv2.resize(region, (args.tile_px, args.tile_px))

    assert(region.shape[0] == region.shape[1] == args.tile_px)

    # Remove the alpha channel and convert to RGB
    if region.shape[-1] == 4:
        region = region[:, :, 0:3]

    # Apply segmentation mask
    if tile_mask is not None:
        region[tile_mask == 0] = (0, 0, 0)

    # Apply normalization
    if args.normalizer:
        try:
            region = args.normalizer.rgb_to_rgb(region)
        except Exception:
            # The image could not be normalized,
            # which happens when a tile is primarily one solid color
            return None

    if args.img_format != 'numpy':
        image = cv2.cvtColor(region, cv2.COLOR_RGB2BGR)
        # Default image quality for JPEG is 95%
        image = cv2.imencode("."+args.img_format, image)[1].tobytes()
    else:
        image = region

    # Include ROI / bounding box processing.
    # Used to visualize ROIs on extracted tiles, or to generate YoloV5 labels.
    if args.yolo or args.draw_roi:
        coords, boxes, yolo_anns = roi_coords_from_image(c, args)
    if args.draw_roi:
        image = draw_roi(image, coords)

    return_dict.update({'image': image})
    if args.yolo:
        return_dict.update({'yolo': yolo_anns})
    return return_dict
241
242
243
class _cuCIMReader:
    """Wrapper for cuCIM reader to preserve cross-compatible functionality.

    The most recently opened ``CuImage`` is cached in the module-level
    ``__cuimage__`` / ``__cuimage_path__`` globals, so constructing another
    reader for the same path reuses the open handle.
    """

    has_levels = True

    def __init__(
        self,
        path: str,
        mpp: Optional[float] = None,
        *,
        cache_kw: Optional[Dict[str, Any]] = None,
        num_workers: int = 0,
        ignore_missing_mpp: bool = True,
        pad_missing: bool = True,
        use_bounds: bool = False,  #TODO: Not yet implemented
    ):
        """Open a slide with cuCIM and collect its pyramid information.

        Args:
            path (str): Path to the slide.
            mpp (float, optional): Microns-per-pixel override. When None, the
                value is searched for in the image metadata.

        Keyword args:
            cache_kw (dict, optional): Stored as ``self.cache_kw``.
            num_workers (int): Passed as ``num_workers`` to cuCIM region reads.
            ignore_missing_mpp (bool): Accepted but not referenced by this
                class.
            pad_missing (bool): Default padding behavior for region reads.
            use_bounds (bool): Not yet implemented.
        """
        global __cuimage__, __cuimage_path__

        from cucim import CuImage

        self.path = path
        self.pad_missing = pad_missing
        self.cache_kw = cache_kw if cache_kw else {}
        self.loaded_downsample_levels = {}  # type: Dict[int, "CuImage"]
        if path == __cuimage_path__:
            self.reader = __cuimage__
        else:
            __cuimage__ = self.reader = CuImage(path)
            __cuimage_path__ = path
        self.num_workers = num_workers
        self._mpp = None

        # Check for Microns-per-pixel (MPP)
        if mpp is not None:
            log.debug(f"Manually setting MPP to {mpp}")
            self._mpp = mpp
        else:
            self._mpp = self._detect_mpp()
        if not self.mpp:
            log.warning("Unable to auto-detect microns-per-pixel (MPP).")

        # Pyramid layers
        self.dimensions = tuple(self.properties['shape'][0:2][::-1])
        level_dimensions = self.level_dimensions
        level_downsamples = self.level_downsamples
        self.levels = [
            {
                'dimensions': level_dimensions[lev],
                'width': level_dimensions[lev][0],
                'height': level_dimensions[lev][1],
                'downsample': level_downsamples[lev],
                'level': lev,
            }
            for lev in range(self.level_count)
        ]

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a float, or None if it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _detect_mpp(self) -> Optional[float]:
        """Search the image metadata for a microns-per-pixel value.

        Each top-level metadata entry is checked, in order, for an ``'MPP'``
        key, then ``'DICOM_PIXEL_SPACING'`` (millimeters), then ``'spacing'``
        together with ``'spacing_units'``. The first usable value wins.

        Returns:
            float or None: MPP in microns, or None if nothing usable is found.
        """
        # Multipliers that convert a spacing in the given unit to microns.
        unit_to_microns = {
            'mm': 1000, 'millimeters': 1000, 'millimeter': 1000,
            'cm': 10000, 'centimeters': 10000, 'centimeter': 10000,
            'um': 1, 'microns': 1, 'micrometers': 1, 'micrometer': 1,
        }
        metadata = self.metadata
        for prop_key in metadata:
            props = metadata[prop_key]
            if not isinstance(props, dict):
                # Only dict entries can carry the keys searched for below.
                continue
            mpp = None
            if 'MPP' in props:
                mpp = self._as_float(props['MPP'])
            elif 'DICOM_PIXEL_SPACING' in props:
                ps = self._as_float(props['DICOM_PIXEL_SPACING'][0])
                if ps is not None:
                    mpp = ps * 1000  # Convert from millimeters -> microns
            elif 'spacing' in props and 'spacing_units' in props:
                ps = props['spacing']
                if isinstance(ps, (list, tuple)):
                    ps = ps[0]
                spacing_unit = props['spacing_units']
                if isinstance(spacing_unit, (list, tuple)):
                    spacing_unit = spacing_unit[0]
                ps = self._as_float(ps)
                try:
                    scale = unit_to_microns.get(spacing_unit)
                except TypeError:
                    # Unhashable unit value; treat it as unknown.
                    scale = None
                if ps is not None and scale is not None:
                    mpp = ps * scale
            if mpp is not None:
                return mpp
        return None

    @property
    def mpp(self):
        """Microns-per-pixel, or None if unknown."""
        return self._mpp

    def has_mpp(self):
        """Return True if a microns-per-pixel value is set."""
        return self._mpp is not None

    @property
    def metadata(self):
        """Metadata of the underlying cuCIM reader."""
        return self.reader.metadata

    @property
    def properties(self):
        """The ``'cucim'`` entry of the reader metadata."""
        return self.reader.metadata['cucim']

    @property
    def resolutions(self):
        """The ``'resolutions'`` entry of the cuCIM properties."""
        return self.properties['resolutions']

    @property
    def level_count(self):
        return self.resolutions['level_count']

    @property
    def level_dimensions(self):
        return self.resolutions['level_dimensions']

    @property
    def level_downsamples(self):
        return self.resolutions['level_downsamples']

    @property
    def level_tile_sizes(self):
        return self.resolutions['level_tile_sizes']

    def best_level_for_downsample(
        self,
        downsample: float,
    ) -> int:
        """Return lowest magnification level with a downsample level lower than
        the given target.

        Args:
            downsample (float): Ratio of target resolution to resolution
                at the highest magnification level. The downsample level of the
                highest magnification layer is equal to 1.

        Returns:
            int:    Index of the last level whose downsample is strictly lower
            than ``downsample``, or 0 if no level qualifies.
        """
        best_level = 0
        for level, level_downsample in enumerate(self.level_downsamples):
            if level_downsample < downsample:
                best_level = level
        return best_level

    def coord_to_raw(self, x, y):
        """Identity mapping: coordinates need no conversion in this backend."""
        return x, y

    def raw_to_coord(self, x, y):
        """Identity mapping: coordinates need no conversion in this backend."""
        return x, y

    def read_level(self, level: int, to_numpy: bool = False):
        """Read a pyramid level.

        Args:
            level (int): Pyramid level to read.
            to_numpy (bool): If True, convert the result with ``cucim2numpy``.

        Returns:
            The object returned by the cuCIM reader, or a numpy array if
            ``to_numpy`` is True.
        """
        image = self.reader.read_region(level=level)
        if to_numpy:
            return cucim2numpy(image)
        return image

    def read_region(
        self,
        base_level_dim: Tuple[int, int],
        downsample_level: int,
        extract_size: Tuple[int, int],
        *,
        convert: Optional[str] = None,
        flatten: bool = False,
        resize_factor: Optional[float] = None,
        pad_missing: Optional[bool] = None
    ) -> Optional[Union["CuImage", np.ndarray, str]]:
        """Extracts a region from the image at the given downsample level.

        Args:
            base_level_dim (Tuple[int, int]): Top-left location of the region
                to extract, using base layer coordinates (x, y)
            downsample_level (int): Downsample level to read.
            extract_size (Tuple[int, int]): Size of the region to read
                (width, height) using downsample layer coordinates.

        Keyword args:
            pad_missing (bool, optional): Pad missing regions with white
                (255). If None, uses the value of the `pad_missing`
                attribute. Defaults to None.
            convert (str, optional): Convert the image to a different format.
                Supported formats are 'jpg', 'jpeg', 'png', and 'numpy'.
                Defaults to None.
            flatten (bool, optional): Flatten the image to 3 channels.
                Defaults to False.
            resize_factor (float, optional): Resize the image by this factor.
                Defaults to None.

        Returns:
            Image in the specified format, or None if the read raised a
            ``ValueError``. The result is a numpy array whenever resizing or
            alpha-channel removal is applied.

        """
        # Define region kwargs
        region_kwargs = dict(
            location=base_level_dim,
            size=(int(extract_size[0]), int(extract_size[1])),
            level=downsample_level,
            num_workers=self.num_workers,
        )
        if pad_missing is None:
            pad_missing = self.pad_missing

        # Pad missing data, if enabled
        if pad_missing:
            try:
                region = cucim_padded_crop(self.reader, **region_kwargs)
            except ValueError as e:
                log.warning(f"Error reading region via padded crop with kwargs=({region_kwargs}): {e}")
                return None
        else:
            # If padding is disabled, this will raise a ValueError.
            try:
                region = self.reader.read_region(**region_kwargs)
            except ValueError as e:
                log.warning(f"Error reading region with kwargs=({region_kwargs}): {e}")
                return None

        fmt = convert.lower() if convert else None
        target_size = None
        if resize_factor:
            target_size = (int(np.round(extract_size[0] * resize_factor)),
                           int(np.round(extract_size[1] * resize_factor)))

        # Slicing and resizing need a numpy array, so convert up front whenever
        # either will be applied or a converted format was requested.
        needs_array = (
            fmt in ('jpg', 'jpeg', 'png', 'numpy')
            or target_size is not None
            or (flatten and region.shape[-1] == 4)
        )
        if needs_array and not isinstance(region, np.ndarray):
            region = cucim2numpy(region)

        if flatten and region.shape[-1] == 4:
            region = region[:, :, 0:3]

        # Resize using the same interpolation strategy as the Libvips backend (cv2).
        if target_size is not None:
            if __cv2_resize__:
                region = cv2.resize(region, target_size)
            else:
                # skimage expects (rows, cols) and returns floats, so swap
                # the axes and convert the result back to uint8.
                region = cucim2numpy(
                    resize(region, (target_size[1], target_size[0]))
                )

        # Final conversions.
        if fmt in ('jpg', 'jpeg'):
            return numpy2jpg(region)
        if fmt == 'png':
            return numpy2png(region)
        return region

    def read_from_pyramid(
        self,
        top_left: Tuple[int, int],
        window_size: Tuple[int, int],
        target_size: Tuple[int, int],
        *,
        convert: Optional[str] = None,
        flatten: bool = False,
        pad_missing: Optional[bool] = None
    ) -> Union[np.ndarray, str]:
        """Reads a region from the image using base layer coordinates.
        Performance is accelerated by pyramid downsample layers, if available.

        Args:
            top_left (Tuple[int, int]): Top-left location of the region to
                extract, using base layer coordinates (x, y).
            window_size (Tuple[int, int]): Size of the region to read (width,
                height) using base layer coordinates.
            target_size (Tuple[int, int]): Resize the region to this target
                size (width, height).

        Keyword args:
            convert (str, optional): Convert the image to a different format.
                Supported formats are 'jpg', 'jpeg', 'png', and 'numpy'.
                Defaults to None.
            flatten (bool, optional): Flatten the image to 3 channels.
                Defaults to False.
            pad_missing (bool, optional): Pad missing regions with white
                (255). If None, uses the value of the `pad_missing`
                attribute. Defaults to None.

        Returns:
            The region resized to ``target_size``: a numpy array, or the
            output of ``numpy2jpg`` / ``numpy2png`` when ``convert`` requests
            one of those formats.
        """
        target_downsample = window_size[0] / target_size[0]
        ds_level = self.best_level_for_downsample(target_downsample)

        # Use a lower downsample level if the window size is too small
        ds = self.level_downsamples[ds_level]
        if not int(window_size[0] / ds) or not int(window_size[1] / ds):
            ds_level = max(0, ds_level-1)
            ds = self.level_downsamples[ds_level]

        # Define region kwargs
        region_kwargs = dict(
            location=top_left,
            size=(int(window_size[0] / ds), int(window_size[1] / ds)),
            level=ds_level,
            num_workers=self.num_workers,
        )
        if pad_missing is None:
            pad_missing = self.pad_missing
        if pad_missing:
            region = cucim_padded_crop(self.reader, **region_kwargs)
        else:
            # These kwargs use the cuCIM reader's parameter names, so they
            # must go to the reader itself rather than to self.read_region().
            region = self.reader.read_region(**region_kwargs)

        # The region is always resized below, which needs a numpy array.
        if not isinstance(region, np.ndarray):
            region = cucim2numpy(region)

        if flatten and region.shape[-1] == 4:
            region = region[:, :, 0:3]

        # Resize using the same interpolation strategy as the Libvips backend (cv2).
        if __cv2_resize__:
            region = cv2.resize(region, target_size)
        else:
            # skimage expects (rows, cols) and returns floats; convert the
            # result back to uint8.
            region = cucim2numpy(
                resize(region, (target_size[1], target_size[0]))
            )

        # Final conversions
        fmt = convert.lower() if convert else None
        if fmt in ('jpg', 'jpeg'):
            return numpy2jpg(region)
        if fmt == 'png':
            return numpy2png(region)
        return region

    def thumbnail(
        self,
        width: int = 512,
        level: Optional[int] = None,
        associated: bool = False
    ) -> np.ndarray:
        """Return thumbnail of slide as numpy array.

        Args:
            width (int): Width of the thumbnail in pixels. The height is
                derived from the slide's aspect ratio.
            level (int, optional): Pyramid level to read. Defaults to the
                last level (``level_count - 1``).
            associated (bool): Not implemented for this backend; only logged.

        Returns:
            np.ndarray: Thumbnail image.
        """
        if associated:
            log.debug("associated=True not implemented for cucim() thumbnail,"
                      "reading from lowest-magnification layer.")
        if level is None:
            level = self.level_count - 1
        w, h = self.dimensions
        height = int((width * h) / w)
        img = self.read_level(level=level)
        if __cv2_resize__:
            return cv2.resize(cucim2numpy(img), (width, height))
        # skimage expects the output shape as (rows, cols).
        return cucim2numpy(resize(np.asarray(img), (height, width)))