
b/pathaia/util/management.py
"""
Helpful functions to extract and organize data.

It takes advantage of the common structure of pathaia projects to enable
dataset creation and experiment monitoring/evaluation.
"""

import pandas as pd
import os
import warnings
from typing import Optional, Sequence, Tuple, Iterator, List
from .types import Patch, PathLike
from glob import glob
import numpy as np
from tensorflow.keras.applications import resnet50
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D
from ..datasets.data import get_tf_dataset
from tqdm import tqdm


class Error(Exception):
    """Base class for pathaia custom errors."""

    pass


class LevelNotFoundError(Error):
    """Raise when trying to access an unknown pyramid level."""

    pass


class EmptyProjectError(Error):
    """Raise when a project folder contains no patch files."""

    pass


class SlideNotFoundError(Error):
    """Raise when a slide file or folder cannot be found."""

    pass


class PatchesNotFoundError(Error):
    """Raise when no extracted patches can be found for a slide."""

    pass


class UnknownColumnError(Error):
    """Raise when a requested column is missing from a patch csv."""

    pass


def get_patch_csv_from_patch_folder(patch_folder: str) -> str:
    """
    Give the csv patch file for a given slide patch folder.

    Check existence of the path and return the absolute path of the csv.

    Args:
        patch_folder: absolute path to a pathaia slide folder.

    Returns:
        Absolute path of the csv patch file.

    """
    if os.path.isdir(patch_folder):
        patch_file = os.path.join(patch_folder, "patches.csv")
        if os.path.exists(patch_file):
            return patch_file
        raise PatchesNotFoundError(
            "Could not find extracted patches for the slide: {}".format(
                patch_folder
            )
        )
    raise SlideNotFoundError(
        "Could not find a patch folder at: {}".format(patch_folder)
    )


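# Illustrative usage of get_patch_csv_from_patch_folder (not part of the
# original module); the project path below is hypothetical:
# >>> get_patch_csv_from_patch_folder("/data/my_project/slide_001")
# '/data/my_project/slide_001/patches.csv'

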
def get_patch_folders_in_project(project_folder: str) -> Iterator[PathLike]:
    """
    Give pathaia slide folders from a pathaia project folder.

    Check existence of the project and recursively yield the folders inside
    it that contain a 'patches.csv' file.

    Args:
        project_folder: absolute path to a pathaia project folder.

    Yields:
        Absolute path to folders containing patch csv files.

    """
    for folder in glob(os.path.join(project_folder, '*')):
        patch_file = os.path.join(folder, "patches.csv")
        if os.path.exists(patch_file):
            yield folder
        else:
            yield from get_patch_folders_in_project(folder)


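# Illustrative usage of get_patch_folders_in_project (not part of the original
# module); the project layout below is hypothetical:
# >>> for folder in get_patch_folders_in_project("/data/my_project"):
# ...     print(folder)
# /data/my_project/slide_001
# /data/my_project/batch_2/slide_002

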
def get_slide_file(
    slide_folder: str, project_folder: str, patch_folder: str,
    extensions: Sequence[str] = ('.mrxs', '.svs')
) -> str:
    """
    Give the absolute path to a slide file.

    Map a pathaia patch folder back to the corresponding WSI file in the
    slide folder, trying each extension in turn.

    Args:
        slide_folder: absolute path to a folder of WSIs.
        project_folder: absolute path to a pathaia project folder.
        patch_folder: absolute path to a folder containing a 'patches.csv'.
        extensions: slide file extensions to try, in order.

    Returns:
        Absolute path of the WSI.

    """
    if not os.path.isdir(slide_folder):
        raise SlideNotFoundError(
            "Could not find a slide folder at: {}".format(slide_folder)
        )
    for ext in extensions:
        slide = patch_folder.replace(project_folder, slide_folder) + ext
        if os.path.exists(slide):
            return slide
    raise SlideNotFoundError(
        "Could not find a slide file with extensions {} for: {} in {}".format(
            extensions, patch_folder, slide_folder
        )
    )


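# Illustrative sketch of how get_slide_file maps a patch folder back to a WSI
# (not part of the original module); all paths below are hypothetical and
# assume '/data/slides/slide_001.mrxs' exists:
# >>> get_slide_file("/data/slides", "/data/my_project",
# ...                "/data/my_project/slide_001")
# '/data/slides/slide_001.mrxs'

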
def read_patch_file(
    patch_file: str, slide_path: str, column: Optional[str] = None,
    level: Optional[int] = None
) -> Iterator[Tuple[dict, str]]:
    """
    Read a patch file.

    Read lines of the patch csv, optionally filtering on a pyramid level and
    looking for the 'column' label.

    Args:
        patch_file: absolute path to a csv patch file.
        slide_path: absolute path to a slide file.
        column: header of the column to use to label individual patches.
        level: pyramid level to query patches in the csv.

    Yields:
        Tuples (patch dict, label); the label is None if 'column' is absent.

    """
    df = pd.read_csv(patch_file)
    if level is not None:
        df = df[df["level"] == level]
    has_column = column in df
    for _, row in df.iterrows():
        patch = {
            "x": row["x"],
            "y": row["y"],
            "dx": row["dx"],
            "dy": row["dy"],
            "id": row["id"],
            "level": row["level"],
            "slide_path": slide_path,
            "slide": slide_path,
            "slide_name": os.path.basename(slide_path)
        }
        yield patch, row[column] if has_column else None


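# Illustrative usage of read_patch_file (not part of the original module);
# the csv path, slide path and 'annotation' column are hypothetical:
# >>> for patch, label in read_patch_file(
# ...     "/data/my_project/slide_001/patches.csv",
# ...     "/data/slides/slide_001.mrxs",
# ...     column="annotation", level=1,
# ... ):
# ...     print(patch["id"], patch["level"], label)

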
def write_slide_predictions(
    slide_predictions: Iterator[Patch], slide_csv: str, column: str
):
    """
    Write slide predictions into a pathaia slide csv.

    Args:
        slide_predictions: iterator on patch dicts.
        slide_csv: absolute path to a pathaia slide csv.
        column: name of the prediction column to append to the csv.

    """
    patch_df = pd.read_csv(slide_csv, sep=None, engine="python")
    patch_df = patch_df.set_index("id")
    for patch in slide_predictions:
        idx = patch["id"]
        pred = patch[column]
        patch_df.loc[idx, column] = pred
    # keep the 'id' index as a column in the rewritten csv
    patch_df.to_csv(slide_csv)


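# Illustrative usage of write_slide_predictions (not part of the original
# module); the patch id, csv path and 'tumor_score' column are hypothetical:
# >>> preds = [{"id": "1_42", "tumor_score": 0.87}]
# >>> write_slide_predictions(
# ...     preds, "/data/my_project/slide_001/patches.csv", column="tumor_score"
# ... )

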
def descriptors_to_csv(
    descriptors: List[Tuple], filename: str, patch_list: List[Patch]
):
    """
    Write patch embeddings into a csv file.

    Args:
        descriptors: list of feature vectors, one per patch.
        filename: absolute path to the output csv file.
        patch_list: list of patch dicts matching the descriptors.

    """
    columns = ['id', 'level', 'x', 'y']
    descriptors = np.asarray(descriptors)
    for i in range(descriptors.shape[1]):
        columns.append(f'{i}')
    rows = []
    for x in range(len(patch_list)):
        data = {'id': patch_list[x]['id'],
                'level': patch_list[x]['level'],
                'x': patch_list[x]['x'],
                'y': patch_list[x]['y']}
        for i in range(descriptors.shape[1]):
            data[f'{i}'] = descriptors[x, i]
        rows.append(data)
    descriptor_df = pd.DataFrame(rows, columns=columns)
    descriptor_df.to_csv(filename, index=False)


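# Illustrative usage of descriptors_to_csv (not part of the original module);
# the descriptors, patch dicts and output path below are hypothetical:
# >>> feats = np.random.rand(2, 2048)  # two patches, 2048-d features each
# >>> patches = [
# ...     {"id": "1_0", "level": 1, "x": 0, "y": 0},
# ...     {"id": "1_1", "level": 1, "x": 224, "y": 0},
# ... ]
# >>> descriptors_to_csv(feats, "/tmp/features_ResNet50.csv", patches)

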
class PathaiaHandler(object):
    """
    A class to handle simple patch datasets.

    It mainly builds the inputs of the tf datasets proposed in
    pathaia.datasets.data.

    Args:
        project_folder: absolute path to a pathaia project.
        slide_folder: absolute path to a slide folder.

    """

    def __init__(self, project_folder: str, slide_folder: str):
        """Init PathaiaHandler."""
        self.slide_folder = slide_folder
        self.project_folder = project_folder

    def _iter_slides(self) -> Iterator[Tuple[str, str]]:
        """Yield slide folders with associated 'patches.csv'."""
        for folder in get_patch_folders_in_project(self.project_folder):
            try:
                slide_path = get_slide_file(
                    self.slide_folder, self.project_folder, folder
                )
                patch_file = get_patch_csv_from_patch_folder(folder)
            except (
                PatchesNotFoundError, UnknownColumnError, SlideNotFoundError
            ) as e:
                warnings.warn(str(e))
                # skip slides whose files or patches cannot be found
                continue
            yield slide_path, patch_file

    def random_split(
        self, ratio: float = 0.3
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """
        Split the whole slide dataset into training/validation.

        Args:
            ratio: ratio of slides to keep for validation.

        Returns:
            Training and validation datasets.

        """
        slides = list(self._iter_slides())
        np.random.shuffle(slides)
        split = int(ratio * len(slides))
        validation = slides[:split]
        training = slides[split:]
        return training, validation

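    # Illustrative usage of random_split (not part of the original class);
    # the folders below are hypothetical:
    # >>> handler = PathaiaHandler("/data/my_project", "/data/slides")
    # >>> train_slides, valid_slides = handler.random_split(ratio=0.3)
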
    def list_patches(
        self, level: int, dim: Tuple[int, int],
        column: Optional[str] = None, slides: Iterator = None
    ) -> Tuple[List[Patch], List[str]]:
        """
        Create a labeled patch dataset.

        Args:
            level: pyramid level to extract patches in csv.
            dim: dimensions of the patches in pixels.
            column: column header in csv to use as a category.
            slides: optional iterator on (slide_path, patch_file) pairs;
                defaults to every slide in the project.

        Returns:
            List of patch dicts and list of labels.

        """
        patch_list = []
        labels = []
        slide_list = self._iter_slides()
        if slides is not None:
            slide_list = slides
        for slide_path, patch_file in slide_list:
            try:
                # read patch file and get the right level
                for patch, label in read_patch_file(
                    patch_file, slide_path, column, level
                ):
                    patch_list.append(patch)
                    labels.append(label)
            except (
                PatchesNotFoundError, UnknownColumnError, SlideNotFoundError
            ) as e:
                warnings.warn(str(e))
        return patch_list, labels

    def extract_features(
        self,
        model_name: str = 'ResNet50',
        slides: Iterator = None,
        patch_size: int = 224,
        level: Optional[int] = None,
        layer: str = '',
        batch_size: int = 128
    ):
        """
        Extract features from patches with a model from keras applications.

        Descriptors are written to a 'features_<model_name>.csv' file next to
        each slide's 'patches.csv'.
        """
        models = {
            'ResNet50': {
                'model': resnet50.ResNet50,
                'module': resnet50
            }
        }
        preproc = models[model_name]['module'].preprocess_input
        ModelClass = models[model_name]['model']
        model = ModelClass(weights='imagenet', include_top=False,
                           pooling='avg',
                           input_shape=(patch_size, patch_size, 3))
        if layer:
            # cut the network at the requested layer and pool its output
            layer_model = Model(inputs=model.input,
                                outputs=model.get_layer(layer).output)
            model = Sequential()
            model.add(layer_model)
            model.add(GlobalAveragePooling2D())
        slide_list = self._iter_slides()
        if slides is not None:
            slide_list = slides
        for slide_path, patch_file in tqdm(slide_list):
            patch_list = []
            label_list = []
            try:
                # read patch file and get the right level
                for patch, _ in read_patch_file(patch_file, slide_path,
                                                level=level):
                    patch_list.append(patch)
                    label_list.append(0)
            except (
                PatchesNotFoundError, UnknownColumnError, SlideNotFoundError
            ) as e:
                warnings.warn(str(e))
            if len(patch_list) == 0:
                # TODO: raise an error here instead of skipping the slide
                print(f'No patches for slide {slide_path}')
                continue
            patch_set = get_tf_dataset(patch_list, label_list, preproc,
                                       batch_size=batch_size,
                                       patch_size=patch_size,
                                       training=False)
            descriptors = model.predict(patch_set)
            descriptor_csv = os.path.join(
                os.path.dirname(patch_file), f'features_{model_name}.csv'
            )
            descriptors_to_csv(descriptors, descriptor_csv, patch_list)
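

# Illustrative end-to-end usage of PathaiaHandler (not part of the original
# module); the folders and the 'annotation' column are hypothetical:
# >>> handler = PathaiaHandler("/data/my_project", "/data/slides")
# >>> patches, labels = handler.list_patches(
# ...     level=1, dim=(224, 224), column="annotation"
# ... )
# >>> handler.extract_features(model_name='ResNet50', patch_size=224, level=1)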