|
a |
|
b/pathaia/datasets/data.py |
|
|
1 |
"""A module to handle data generation for deep neural networks. |
|
|
2 |
|
|
|
3 |
It uses the tf.data.Dataset object to enable parallel computing of batches. |
|
|
4 |
""" |
|
|
5 |
import numpy as np |
|
|
6 |
import openslide |
|
|
7 |
import tensorflow as tf |
|
|
8 |
from typing import Sequence, Callable, Iterator, Any, Tuple, Optional, Dict, Union |
|
|
9 |
from ..util.types import Patch, NDByteImage |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def slide_query(patch: Patch, patch_size: int) -> NDByteImage:
    """
    Query patch image in slide.

    Get patch image given position, level and dimensions. The slide file is
    opened for this single query and is always closed afterwards, even if the
    read fails.

    Args:
        patch: the patch to query.
        patch_size: size of side of the patch in pixels.

    Returns:
        Numpy array rgb image of the patch.

    """
    slide = openslide.OpenSlide(patch["slide"])
    try:
        pil_img = slide.read_region(
            (patch["x"], patch["y"]), patch["level"], (patch_size, patch_size)
        )
        # read_region returns an RGBA image; drop the alpha channel.
        return np.array(pil_img)[:, :, 0:3]
    finally:
        # Release the OS-level file handle held by OpenSlide (was leaked before).
        slide.close()
|
|
31 |
|
|
|
32 |
|
|
|
33 |
def fast_slide_query(
    slides: Dict[str, openslide.OpenSlide],
    patch: Patch,
    patch_size: int
) -> NDByteImage:
    """
    Query patch image in an already-opened slide.

    Get patch image given a mapping of opened slide objects, the position,
    level and dimensions. Avoids reopening the slide file on every query.

    Args:
        slides: mapping from slide path to its opened OpenSlide object.
        patch: the patch to query.
        patch_size: size of side of the patch in pixels.

    Returns:
        Numpy array rgb image of the patch.

    """
    location = (patch["x"], patch["y"])
    dimensions = (patch_size, patch_size)
    region = slides[patch["slide"]].read_region(
        location, patch["level"], dimensions
    )
    # read_region yields RGBA; keep only the RGB channels.
    return np.array(region)[:, :, :3]
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def generator_fn(
    patch_list: Sequence[Patch],
    label_list: Sequence[Any],
    patch_size: int,
    preproc: Callable
) -> Iterator[Tuple[Patch, Any]]:
    """
    Provide a generator for tf.data.Dataset.

    Create a scope with required arguments, but produce an arg-less iterator.

    Args:
        patch_list: patch list to query.
        label_list: label of patches.
        patch_size: size of the side of the patches in pixels.
        preproc: a preprocessing function for images.
    Returns:
        A generator of tuples (patch, label).

    """
    def _iter_samples():
        # Pair each patch with its label; images are fetched and preprocessed
        # lazily, one sample at a time, as the dataset pulls from the iterator.
        for current_patch, label in zip(patch_list, label_list):
            image = slide_query(current_patch, patch_size)
            yield preproc(image), label

    return _iter_samples
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def get_tf_dataset(
    patch_list: Sequence[Patch],
    label_list: Any,
    preproc: Callable,
    batch_size: int,
    patch_size: int,
    prefetch: Optional[int] = None,
    training: Optional[bool] = True,
) -> tf.data.Dataset:
    """
    Create tensorflow dataset.

    Create tf.data.Dataset object able to prefetch and batch samples from generator.

    Args:
        patch_list: patch list to query.
        label_list: label of patches.
        preproc: a preprocessing function for images.
        batch_size: number of samples per batch.
        patch_size: size (pixel) of the side of a square patch.
        prefetch: number of batches to prefetch; autotuned when None.
        training: when True, drop incomplete batches and repeat indefinitely.

    Returns:
        tf.data.Dataset: a proper tensorflow dataset to fit on.

    """
    gen = generator_fn(patch_list, label_list, patch_size, preproc)
    try:
        shape_label = label_list[0].shape
    except (AttributeError, IndexError):
        # Scalar labels (no .shape attribute) or an empty label list:
        # leave the label shape unknown.
        shape_label = None
    dataset = tf.data.Dataset.from_generator(
        generator=gen,
        output_types=(np.float32, np.int32),
        output_shapes=((patch_size, patch_size, 3), shape_label),
    )
    if training:
        # Fixed-size batches + infinite repeat so fit() can rely on
        # steps_per_epoch with a constant batch shape.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.repeat()
    else:
        dataset = dataset.batch(batch_size, drop_remainder=False)
    # prefetch
    # <=> while fitting batch b, prepare b+k in parallel
    if prefetch is None:
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.prefetch(prefetch)
    return dataset