--- a/pathaia/datasets/data.py
+++ b/pathaia/datasets/data.py
@@ -0,0 +1,135 @@
+"""A module to handle data generation for deep neural networks.
+
+It uses the tf.data.Dataset object to enable parallel computing of batches.
+"""
+import numpy as np
+import openslide
+import tensorflow as tf
+from typing import Sequence, Callable, Iterator, Any, Tuple, Optional, Dict
+from ..util.types import Patch, NDByteImage
+
+
+def slide_query(patch: Patch, patch_size: int) -> NDByteImage:
+    """
+    Query patch image in slide.
+
+    Get patch image given position, level and dimensions.
+
+    Args:
+        patch: the patch to query.
+        patch_size: size of the side of the patch in pixels.
+
+    Returns:
+        Numpy array rgb image of the patch.
+
+    """
+    slide = openslide.OpenSlide(patch["slide"])
+    pil_img = slide.read_region(
+        (patch["x"], patch["y"]), patch["level"], (patch_size, patch_size)
+    )
+    return np.array(pil_img)[:, :, 0:3]
+
+
+def fast_slide_query(
+    slides: Dict[str, openslide.OpenSlide],
+    patch: Patch,
+    patch_size: int,
+) -> NDByteImage:
+    """
+    Query patch image in an already opened slide.
+
+    Get patch image given the slide object, position, level and dimensions.
+
+    Args:
+        slides: dictionary mapping slide paths to opened OpenSlide objects.
+        patch: the patch to query.
+        patch_size: size of the side of the patch in pixels.
+
+    Returns:
+        Numpy array rgb image of the patch.
+
+    """
+    slide = slides[patch["slide"]]
+    pil_img = slide.read_region(
+        (patch["x"], patch["y"]), patch["level"], (patch_size, patch_size)
+    )
+    return np.array(pil_img)[:, :, 0:3]
+
+
+def generator_fn(
+    patch_list: Sequence[Patch],
+    label_list: Sequence[Any],
+    patch_size: int,
+    preproc: Callable,
+) -> Callable[[], Iterator[Tuple[Any, Any]]]:
+    """
+    Provide a generator for tf.data.Dataset.
+
+    Close over the required arguments to produce an argument-less generator.
+
+    Args:
+        patch_list: patch list to query.
+        label_list: labels of patches.
+        patch_size: size of the side of the patches in pixels.
+        preproc: a preprocessing function for images.
+
+    Returns:
+        A generator function yielding (preprocessed image, label) tuples.
+
+    """
+    def generator():
+        for patch, y in zip(patch_list, label_list):
+            x = slide_query(patch, patch_size)
+            yield preproc(x), y
+
+    return generator
+
+
+def get_tf_dataset(
+    patch_list: Sequence[Patch],
+    label_list: Sequence[Any],
+    preproc: Callable,
+    batch_size: int,
+    patch_size: int,
+    prefetch: Optional[int] = None,
+    training: bool = True,
+) -> tf.data.Dataset:
+    """
+    Create tensorflow dataset.
+
+    Create a tf.data.Dataset object able to prefetch and batch samples from a generator.
+
+    Args:
+        patch_list: patch list to query.
+        label_list: labels of patches.
+        preproc: a preprocessing function for images.
+        batch_size: number of samples per batch.
+        patch_size: size (pixels) of the side of a square patch.
+        prefetch: number of batches to prefetch; autotuned when None.
+        training: whether to drop incomplete batches and repeat indefinitely.
+
+    Returns:
+        tf.data.Dataset: a tensorflow dataset ready to fit on.
+
+    """
+    gen = generator_fn(patch_list, label_list, patch_size, preproc)
+    try:
+        shape_label = label_list[0].shape
+    except AttributeError:
+        shape_label = None
+    dataset = tf.data.Dataset.from_generator(
+        generator=gen,
+        output_types=(np.float32, np.int32),
+        output_shapes=((patch_size, patch_size, 3), shape_label),
+    )
+    if training:
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+        dataset = dataset.repeat()
+    else:
+        dataset = dataset.batch(batch_size, drop_remainder=False)
+    # prefetch <=> while fitting batch b, prepare batch b+k in parallel
+    if prefetch is None:
+        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    else:
+        dataset = dataset.prefetch(prefetch)
+    return dataset
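
Usage sketch (not part of the diff): building a dataset from a handful of patches and fitting a Keras model on it. The slide path, patch coordinates, label values, and the model are hypothetical placeholders; only the keys "slide", "x", "y" and "level" are assumed, since that is what the module reads from each patch.

import numpy as np
from pathaia.datasets.data import get_tf_dataset

# Hypothetical patches: dicts with the keys the module reads.
patches = [
    {"slide": "slide_001.svs", "x": 0, "y": 0, "level": 1},
    {"slide": "slide_001.svs", "x": 224, "y": 0, "level": 1},
]
# Scalar int32 labels, so label_list[0].shape resolves to ().
labels = [np.int32(0), np.int32(1)]

def preproc(img):
    # Scale uint8 RGB values to [0, 1] floats.
    return img.astype(np.float32) / 255.0

dataset = get_tf_dataset(
    patches, labels, preproc, batch_size=2, patch_size=224, training=True
)
# With training=True the dataset repeats forever, so steps_per_epoch
# must be passed to model.fit:
# model.fit(dataset, steps_per_epoch=len(patches) // 2, epochs=5)

Note that this entry point goes through slide_query, which reopens the slide file for every patch; fast_slide_query exists for callers that keep a dict of already opened OpenSlide handles.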