Diff of /pathaia/datasets/data.py [000000] .. [7823dd]

Switch to unified view

a b/pathaia/datasets/data.py
1
"""A module to handle data generation for deep neural networks.
2
3
It uses the tf.data.Dataset object to enable parallel computing of batches.
4
"""
5
import numpy as np
6
import openslide
7
import tensorflow as tf
8
from typing import Sequence, Callable, Iterator, Any, Tuple, Optional, Dict, Union
9
from ..util.types import Patch, NDByteImage
10
11
12
def slide_query(patch: Patch, patch_size: int) -> NDByteImage:
    """
    Query patch image in slide.

    Get patch image given position, level and dimensions.

    Args:
        patch: the patch to query.
        patch_size: size of side of the patch in pixels.

    Returns:
        Numpy array rgb image of the patch.

    """
    slide = openslide.OpenSlide(patch["slide"])
    # Close the handle even if read_region raises: each call opens the
    # slide file anew, so leaving it open leaks file descriptors when
    # many patches are queried.
    try:
        pil_img = slide.read_region(
            (patch["x"], patch["y"]), patch["level"], (patch_size, patch_size)
        )
    finally:
        slide.close()
    # read_region returns an RGBA PIL image; keep only the RGB channels.
    return np.array(pil_img)[:, :, 0:3]
31
32
33
def fast_slide_query(
    slides: Dict[str, openslide.OpenSlide],
    patch: Patch,
    patch_size: int
) -> NDByteImage:
    """
    Query patch image in an already-opened slide.

    Get patch image given a pool of open slide objects, the position,
    level and dimensions.

    Args:
        slides: mapping from slide name to its opened OpenSlide object.
        patch: the patch to query.
        patch_size: size of side of the patch in pixels.

    Returns:
        Numpy array rgb image of the patch.

    """
    source = slides[patch["slide"]]
    location = (patch["x"], patch["y"])
    dimensions = (patch_size, patch_size)
    rgba = np.array(source.read_region(location, patch["level"], dimensions))
    # read_region yields RGBA; drop the alpha channel.
    return rgba[:, :, 0:3]
57
58
59
def generator_fn(
    patch_list: Sequence[Patch],
    label_list: Sequence[Any],
    patch_size: int,
    preproc: Callable
) -> Iterator[Tuple[Patch, Any]]:
    """
    Provide a generator for tf.data.Dataset.

    Create a scope with required arguments, but produce an arg-less iterator.

    Args:
        patch_list: patch list to query.
        label_list: label of patches.
        patch_size: size of the side of the patches in pixels.
        preproc: a preprocessing function for images.
    Returns:
        A generator of tuples (patch, label).

    """
    def gen():
        # Pair each patch with its label; iteration stops at the
        # shorter of the two sequences.
        for current_patch, label in zip(patch_list, label_list):
            image = slide_query(current_patch, patch_size)
            yield preproc(image), label

    return gen
85
86
87
def get_tf_dataset(
    patch_list: Sequence[Patch],
    label_list: Sequence[Any],
    preproc: Callable,
    batch_size: int,
    patch_size: int,
    prefetch: Optional[int] = None,
    training: bool = True,
) -> tf.data.Dataset:
    """
    Create tensorflow dataset.

    Create tf.data.Dataset object able to prefetch and batch samples from generator.

    Args:
        patch_list: patch list to query.
        label_list: label of patches.
        preproc: a preprocessing function for images.
        batch_size: number of samples per batch.
        patch_size: size (pixel) of the side of a square patch.
        prefetch: number of batches to prepare in advance; autotuned when None.
        training: when True, drop the last incomplete batch and repeat the
            dataset forever; when False, produce a single finite pass.

    Returns:
        tf.data.Dataset: a proper tensorflow dataset to fit on.

    """
    gen = generator_fn(patch_list, label_list, patch_size, preproc)
    # Scalar labels (plain ints) have no .shape attribute, and an empty
    # label_list has no first element; both cases mean "unknown shape".
    # The original only caught AttributeError, so an empty list raised
    # IndexError here.
    try:
        shape_label = label_list[0].shape
    except (AttributeError, IndexError):
        shape_label = None
    dataset = tf.data.Dataset.from_generator(
        generator=gen,
        output_types=(np.float32, np.int32),
        output_shapes=((patch_size, patch_size, 3), shape_label),
    )
    if training:
        # Fixed-size batches keep a constant batch dimension for the model,
        # and repeat() leaves epoch counting to the fit loop.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.repeat()
    else:
        dataset = dataset.batch(batch_size, drop_remainder=False)
    # prefetch
    # <=> while fitting batch b, prepare b+k in parallel
    if prefetch is None:
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.prefetch(prefetch)
    return dataset