.. _dataloaders:

Dataloaders: Sampling and Augmentation
======================================

With support for both Tensorflow and PyTorch, Slideflow provides several options for dataset sampling, processing, and augmentation. Here, we'll review the options for creating dataloaders - objects that read and process TFRecord data and return images and labels - in each framework. In all cases, data are read from TFRecords generated through :ref:`filtering`. The TFRecord data format is discussed in more detail in the :ref:`tfrecords` note.

Tensorflow
**********

.. |TFRecordDataset| replace:: ``tf.data.TFRecordDataset``
.. _TFRecordDataset: https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset

The :meth:`slideflow.Dataset.tensorflow()` method provides an easy interface for creating a ``tf.data.Dataset`` that reads and interleaves from TFRecords in a Slideflow dataset. Behind the scenes, this method uses the |TFRecordDataset|_ class for reading and parsing each TFRecord.

The returned ``tf.data.Dataset`` object is an iterable-only dataset whose returned values depend on the arguments provided to the ``.tensorflow()`` function.

If no arguments are provided, the returned dataset will yield a tuple of ``(image, None)``, where the image is a ``tf.Tensor`` of shape ``[tile_height, tile_width, num_channels]`` and type ``tf.uint8``.

If the ``labels`` argument is provided (a dictionary mapping slide names to numeric labels), the returned dataset will yield a tuple of ``(image, label)``, where the label is a ``tf.Tensor`` with a shape and type matching the provided labels.

.. code-block:: python

    import slideflow as sf

    # Create a dataset object
    project = sf.load_project(...)
    dataset = project.dataset(...)

    # Get the labels
    labels, unique_labels = dataset.labels('HPV_status')

    # Create a tensorflow dataset
    # that yields (image, label) tuples
    tf_dataset = dataset.tensorflow(labels=labels)

    for image, label in tf_dataset:
        # Do something with the image and label...
        ...

Slide names and tile locations
------------------------------

Dataloaders can be configured to return slide names and tile locations in addition to the image and label. This is done by providing the ``incl_slidenames`` and ``incl_loc`` arguments to the ``.tensorflow()`` method. Both arguments are boolean values and default to ``False``.

Setting ``incl_slidenames=True`` will return the slide name as a Tensor (dtype=string) after the label. Setting ``incl_loc=True`` will return the x and y locations, both as Tensors (dtype=int64), as the last two values of the tuple.

.. code-block:: python

    tf_dataset = dataset.tensorflow(incl_slidenames=True, incl_loc=True)

    for image, label, slide, loc_x, loc_y in tf_dataset:
        ...

Image preprocessing
-------------------

.. |per_image_standardization| replace:: ``tf.image.per_image_standardization()``
.. _per_image_standardization: https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization

Dataloaders created with ``.tensorflow()`` include several image preprocessing options. These options are provided as keyword arguments to the ``.tensorflow()`` method and are executed in the order listed below (a short example follows the list):

- **crop_left** (int): Crop images at this top-left x/y coordinate. Default is ``None``.
- **crop_width** (int): Crop images to this width. Default is ``None``.
- **resize_target** (int): Resize images to this width/height. Default is ``None``.
- **resize_method** (str): Resize method. Default is ``"lanczos3"``.
- **resize_aa** (bool): Enable antialiasing if resizing. Defaults to ``True``.
- **normalizer** (``StainNormalizer``): Perform stain normalization.
- **augment** (str): Perform augmentations based on the provided string. Combine characters to perform multiple augmentations (e.g. ``'xyrj'``). Options include:
    - ``'n'``: Perform :ref:`stain_augmentation` (done concurrently with stain normalization)
    - ``'j'``: Random JPEG compression (50% chance to compress with quality between 50 and 100)
    - ``'r'``: Random 90-degree rotation
    - ``'x'``: Random horizontal flip
    - ``'y'``: Random vertical flip
    - ``'b'``: Random Gaussian blur (10% chance to blur with sigma between 0.5 and 2.0)
- **transform** (Any): Arbitrary function to apply to each image. The function must accept a single argument (the image) and return a single value (the transformed image).
- **standardize** (bool): Standardize images with |per_image_standardization|_, returning a ``tf.float32`` image. Default is ``False``, returning a ``tf.uint8`` image.

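
For example, the sketch below builds a training dataloader with stain normalization, flip/rotate/JPEG augmentation, and standardization. ``sf.norm.autoselect`` and the ``'reinhard'`` method name are assumptions here; see the stain normalization documentation for the normalizers available in your version:

.. code-block:: python

    import slideflow as sf

    # Build a stain normalizer (assumed API: sf.norm.autoselect)
    normalizer = sf.norm.autoselect('reinhard')

    # Normalize, augment, and standardize each image,
    # returning tf.float32 images
    tf_dataset = dataset.tensorflow(
        labels=labels,
        normalizer=normalizer,
        augment='xyrj',
        standardize=True,
    )
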

Dataset sharding
----------------

Tensorflow dataloaders can be sharded into multiple partitions, ensuring that data is not duplicated when performing distributed training across multiple processes or nodes. This is done by providing the ``shard_idx`` and ``num_shards`` arguments to the ``.tensorflow()`` method. The ``shard_idx`` argument is an integer specifying the shard number, and ``num_shards`` is an integer specifying the total number of shards.

.. code-block:: python

    # Shard the dataset for GPU 1 of 4
    tf_dataset = dataset.tensorflow(
        ...,
        shard_idx=0,
        num_shards=4
    )

PyTorch
*******

.. |DataLoader| replace:: ``torch.utils.data.DataLoader``
.. _DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

As with Tensorflow, the :meth:`slideflow.Dataset.torch()` method creates a |DataLoader|_ that reads images from TFRecords. In the backend, TFRecords are read using :class:`slideflow.tfrecord.torch.MultiTFRecordDataset` and processed as described in :ref:`tfrecords`.

The returned |DataLoader|_ is an iterable-only dataloader whose returned values depend on the arguments provided to the ``.torch()`` function. An indexable, map-style dataset is also available when using PyTorch, as described in :ref:`indexable_dataloader`.

If no arguments are provided, the returned dataloader will yield a tuple of ``(image, None)``, where the image is a ``torch.Tensor`` of shape ``[num_channels, tile_height, tile_width]`` and type ``torch.uint8``. Labels are assigned as described above. Slide names and tile locations can also be returned, using the same arguments as `described above <https://slideflow.dev/dataloaders/#slide-names-and-tile-locations>`_.

.. code-block:: python

    import slideflow as sf

    # Create a dataset object
    project = sf.load_project(...)
    dataset = project.dataset(...)

    # Create a PyTorch dataloader
    torch_dl = dataset.torch()

    for image, label in torch_dl:
        # Do something with the image...
        ...

Image preprocessing
-------------------

Dataloaders created with ``.torch()`` include several image preprocessing options, provided as keyword arguments to the ``.torch()`` method. These preprocessing steps are executed in the order listed below:

- **normalizer** (``StainNormalizer``): Perform stain normalization.
- **augment** (str): Perform augmentations based on the provided string. Combine characters to perform multiple augmentations (e.g. ``'xyrj'``). Augmentations are executed in the order characters appear in the string. Options include:
    - ``'n'``: Perform :ref:`stain_augmentation` (done concurrently with stain normalization)
    - ``'j'``: Random JPEG compression (50% chance to compress with quality between 50 and 100)
    - ``'r'``: Random 90-degree rotation
    - ``'x'``: Random horizontal flip
    - ``'y'``: Random vertical flip
    - ``'b'``: Random Gaussian blur (10% chance to blur with sigma between 0.5 and 2.0)
- **transform** (Any): Arbitrary function to apply to each image, including `torchvision transforms <https://pytorch.org/vision/main/transforms.html>`_. The function must accept a single argument (the image, in ``(num_channels, height, width)`` format) and return a single value (the transformed image).
- **standardize** (bool): Standardize images with ``image / 127.5 - 1``, returning a ``torch.float32`` image. Default is ``False``, returning a ``torch.uint8`` image.

Below is an example of using the ``transform`` argument to apply a torchvision transform to each image. Note that because ``transform`` is applied before ``standardize``, images arrive as ``torch.uint8`` and must be converted to float before normalizing:

.. code-block:: python

    import torch
    import torchvision.transforms as T

    # Create a torch dataloader
    torch_dataloader = dataset.torch(
        transform=T.Compose([
            T.RandomResizedCrop(size=(224, 224), antialias=True),
            # Convert uint8 images to float in [0, 1] before normalizing
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    )

    for image, label in torch_dataloader:
        # Do something with the image and label...
        ...

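
The ``transform`` argument also accepts plain Python functions. For example, a minimal sketch of a custom transform (the ``pad_to_square`` helper below is hypothetical, shown only to illustrate the expected signature):

.. code-block:: python

    import torch

    def pad_to_square(img: torch.Tensor) -> torch.Tensor:
        # Images arrive in (num_channels, height, width) format
        _, h, w = img.shape
        size = max(h, w)
        padded = torch.zeros((img.shape[0], size, size), dtype=img.dtype)
        padded[:, :h, :w] = img
        return padded

    torch_dataloader = dataset.torch(transform=pad_to_square)
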

Dataset sharding
----------------

PyTorch dataloaders can similarly be sharded into multiple partitions, ensuring that data is not duplicated when performing distributed training across multiple processes or nodes.

Sharding is done in two stages. First, dataloaders can be split into partitions using the ``rank`` and ``num_replicas`` arguments to the ``.torch()`` method. The ``rank`` argument is an integer specifying the rank of the current process, and ``num_replicas`` is an integer specifying the total number of processes.

.. code-block:: python

    # Shard the dataset for GPU 1 of 4
    torch_dataloader = dataset.torch(
        ...,
        rank=0,
        num_replicas=4
    )

The second stage of sharding happens in the background: if a dataloader is built with multiple worker processes (``Dataset.torch(num_workers=...)``), partitions will be automatically further subdivided into smaller chunks, ensuring that each worker process reads a unique subset of the data.

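
For example, a sketch combining both stages for the first process of a four-GPU job, with four workers per dataloader:

.. code-block:: python

    # Rank-level sharding is explicit; worker-level sharding
    # is handled automatically in the background.
    torch_dataloader = dataset.torch(
        ...,
        rank=0,
        num_replicas=4,
        num_workers=4
    )
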

Labeling
********

The ``labels`` argument to the ``.tensorflow()`` and ``.torch()`` methods accepts a dictionary mapping slide names to numeric labels. During TFRecord reading, the slide name is used to look up the label from the provided dictionary.

.. warning::

    Labels are assigned to image tiles based on the slide names stored inside a :ref:`tfrecord <tfrecords>` file, not the filename of the TFRecord. This means that renaming a TFRecord file will not change the label of the tiles inside the file. If you need to change the slide names associated with tiles inside a TFRecord, the TFRecord file must be regenerated.

The most common way to generate labels is with the :meth:`slideflow.Dataset.labels()` method, which returns a dictionary mapping slide names to numeric labels. For categorical labels, each numeric label corresponds to the index of the label in the ``unique_labels`` list. For example, if the ``unique_labels`` list is ``['HPV-', 'HPV+']``, then the mapping of numeric labels would be ``{'HPV-': 0, 'HPV+': 1}``.

.. code-block:: python

    >>> labels, unique_labels = dataset.labels('HPV_status')
    >>> unique_labels
    ['HPV-', 'HPV+']
    >>> labels
    {'slide1': 0,
     'slide2': 1,
     ...
    }
    >>> tf_dataset = dataset.tensorflow(labels=labels)

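
The same dictionary can be passed to ``.torch()`` to label a PyTorch dataloader:

.. code-block:: python

    >>> torch_dl = dataset.torch(labels=labels)
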

.. _sampling:

Sampling
********

Dataloaders created with ``.tensorflow()`` and ``.torch()`` are iterable-only dataloaders, meaning that they cannot be indexed directly. This is because the underlying TFRecords are sampled in a streaming fashion, and the dataloader does not know what the next record will be until it has been read. This is in contrast to the :ref:`indexable_dataloader` method described below, which creates an indexable, map-style dataset.

Dataloaders created with ``.tensorflow()`` and ``.torch()`` can be configured to sample from TFRecords in several ways, with options for infinite vs. finite sampling, oversampling, and undersampling. These sampling methods are described below.

Infinite dataloaders
--------------------

By default, dataloaders created with ``.tensorflow()`` and ``.torch()`` sample from TFRecords in an infinite loop. This is useful for training, where the dataloader should continue to yield images until the training process is complete. Images are sampled from TFRecords with uniform probability, meaning that each TFRecord has an equal chance of yielding an image. This sampling strategy can be configured, as described below.

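
Because an infinite dataloader never stops on its own, iteration must be ended manually. A minimal sketch:

.. code-block:: python

    # Infinite sampling is the default
    torch_dl = dataset.torch(labels=labels)

    # Stop manually after a fixed number of images
    for i, (image, label) in enumerate(torch_dl):
        if i >= 1000:
            break
        # Do something with the image and label...
        ...
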

.. note::

    When training :ref:`tile-based models <training>`, a dataloader is considered to have yielded one "epoch" of data when it has yielded the number of images equal to the number of tiles in the dataset. Due to the random sampling from TFRecords, this means that some images will be overrepresented (images from TFRecords with fewer tiles) and some will be underrepresented (images from TFRecords with many tiles).

Finite dataloaders
------------------

Dataloaders can also be configured with finite sampling, yielding tiles from TFRecords exactly once. This is accomplished by passing the argument ``infinite=False`` to the ``.tensorflow()`` or ``.torch()`` methods.

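
Finite sampling is typically used for evaluation, where each tile should be seen exactly once:

.. code-block:: python

    # Yield each tile exactly once, then stop iterating
    torch_dl = dataset.torch(infinite=False)

    for image, label in torch_dl:
        ...
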

.. _balancing:

Oversampling with balancing
---------------------------

Oversampling methods control the probability that tiles are read from each TFRecord, affecting the balance of data across slides, patients, and outcome categories. Oversampling is configured at the Dataset level using the :meth:`slideflow.Dataset.balance` method, which returns a copy of the dataset with the specified oversampling strategy.

**Slide-level balancing**: By default, images are sampled from TFRecords with uniform probability, meaning that each TFRecord has an equal chance of yielding an image. This is equivalent to both ``.balance(strategy='slide')`` and ``.balance(strategy=None)``. This strategy will oversample images from slides with fewer tiles, and undersample images from slides with more tiles.

.. code-block:: python

    # Sample from TFRecords with equal probability
    dataset = dataset.balance(strategy='slide')

**Patient-level balancing**: To sample from TFRecords with equal probability across patients, use ``.balance(strategy='patient')``. This strategy will oversample images from patients with fewer tiles, and undersample images from patients with more tiles.

.. code-block:: python

    # Sample from TFRecords with equal
    # probability across patients.
    dataset = dataset.balance(strategy='patient')

**Tile-level balancing**: To sample from TFRecords with uniform probability across image tiles, use ``.balance(strategy='tile')``. This strategy will sample from TFRecords with probability proportional to the number of tiles in the TFRecord, resulting in higher representation of slides with more tiles.

.. code-block:: python

    # Sample from TFRecords with probability proportional
    # to the number of tiles in each TFRecord.
    dataset = dataset.balance(strategy='tile')

**Category-level balancing**: To sample from TFRecords with equal probability across outcome categories, use ``.balance(strategy='category')``. This strategy will oversample images from outcome categories with fewer tiles, and undersample images from outcome categories with more tiles. This strategy will also perform slide-level balancing within each category. Category-level balancing is only available when using categorical labels.

.. code-block:: python

    # Sample with equal probability from
    # the categories "HPV-" and "HPV+".
    dataset = dataset.balance("HPV_status", strategy='category')

**Custom balancing**: The ``.balance()`` method saves sampling probability weights to ``Dataset.prob_weights``, a dictionary mapping TFRecord paths to sampling weights. Custom balancing can be performed by overriding this dictionary with custom weights.

.. code-block:: python

    >>> dataset = dataset.balance(strategy='slide')
    >>> dataset.prob_weights
    {'/path/to/tfrecord1': 0.002,
     '/path/to/tfrecord2': 0.003,
     ...
    }
    >>> dataset.prob_weights = {...}

Balancing is automatically applied to dataloaders created with the ``.tensorflow()`` and ``.torch()`` methods.

Undersampling with clipping
---------------------------

Datasets can also be configured to undersample TFRecords using :meth:`slideflow.Dataset.clip`. Several undersampling strategies are available, described below; a code example of each follows the descriptions.

**Slide-level clipping**: TFRecords can be clipped to a maximum number of tiles per slide using ``.clip(max_tiles)``. This strategy will clip TFRecords with more tiles than the specified ``max_tiles`` value, resulting in a maximum of ``max_tiles`` tiles per slide.

**Patient-level clipping**: TFRecords can be clipped to a maximum number of tiles per patient using ``.clip(max_tiles, strategy='patient')``. For patients with more than one slide/TFRecord, TFRecords will be clipped proportionally.

**Outcome-level clipping**: TFRecords can also be clipped to a maximum number of tiles per outcome category using ``.clip(max_tiles, strategy='category', headers=...)``. The outcome category is specified by the ``headers`` argument, which can be a single header name or a list of header names. Within each category, TFRecords will be clipped proportionally.

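
For example (using the ``HPV_status`` outcome from the examples above):

.. code-block:: python

    # Clip each slide's TFRecord to a max of 100 tiles
    dataset = dataset.clip(100)

    # Clip to a max of 100 tiles per patient
    dataset = dataset.clip(100, strategy='patient')

    # Clip to a max of 100 tiles per outcome category
    dataset = dataset.clip(100, strategy='category', headers='HPV_status')
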

**Custom clipping**: The ``.clip()`` method saves clipping values to ``Dataset._clip``, a dictionary mapping TFRecord paths to the number of tiles that should be sampled from each TFRecord. Custom clipping can be performed by overriding this dictionary with custom values.

.. code-block:: python

    >>> dataset = dataset.clip(100)
    >>> dataset._clip
    {'/path/to/tfrecord1': 76,
     '/path/to/tfrecord2': 100,
     ...
    }
    >>> dataset._clip = {...}

Undersampling via dataset clipping is automatically applied to dataloaders created with ``.tensorflow()`` and ``.torch()``.

During training
---------------

If you are training a Slideflow model by directly providing a training and validation dataset to the :meth:`slideflow.Project.train` method, you can configure the datasets to perform oversampling and undersampling as described above. For example:

.. code-block:: python

    import slideflow as sf

    # Load a project
    project = sf.load_project(...)

    # Configure a training dataset with tile-level balancing
    # and clipping to max 100 tiles per TFRecord
    train = project.dataset(...).balance(strategy='tile').clip(100)

    # Get a validation dataset
    val = project.dataset(...)

    # Train a model
    project.train(
        ...,
        dataset=train,
        val_dataset=val,
    )

Alternatively, you can configure oversampling during training through the ``training_balance`` and ``validation_balance`` hyperparameters, as described in the :ref:`ModelParams <model_params>` documentation. Undersampling with dataset clipping can be performed with the ``max_tiles`` argument. Configuring oversampling/undersampling with this method propagates the configuration to all datasets generated during cross-validation.

.. code-block:: python

    import slideflow as sf

    # Load a project
    project = sf.load_project(...)

    # Configure hyperparameters with tile-level
    # balancing/oversampling for the training data
    hp = sf.ModelParams(
        ...,
        training_balance='tile',
        validation_balance=None,
    )

    # Train a model.
    # Undersample/clip data to max 100 tiles per TFRecord.
    project.train(
        ...,
        params=hp,
        max_tiles=100
    )

.. _indexable_dataloader:

Direct indexing
***************

An indexable, map-style dataset can be created for PyTorch using :class:`slideflow.io.torch.IndexedInterleaver`, which returns a ``torch.utils.data.Dataset``. Indexable datasets are only available for the PyTorch backend.

This indexable dataset is created from a list of TFRecords and accepts many arguments for controlling labels, augmentation, and image transformations.

.. code-block:: python

    import slideflow as sf
    import torchvision.transforms as T
    from slideflow.io.torch import IndexedInterleaver

    # Create a dataset object
    project = sf.load_project(...)
    dataset = project.dataset(...)

    # Get the TFRecords
    tfrecords = dataset.tfrecords()

    # Assemble labels
    labels, _ = dataset.labels("HPV_status")

    # Create an indexable dataset
    dts = IndexedInterleaver(
        tfrecords,
        labels=labels,
        augment="xyrj",
        transform=T.Compose([
            T.RandomResizedCrop(size=(224, 224),
                                antialias=True),
        ]),
        normalizer=None,
        standardize=True,
        shuffle=True,
        seed=42,
    )

The returned dataset is indexable, meaning that it can be indexed directly to retrieve a single image and label.

.. code-block:: python

    >>> len(dts)
    284114
    >>> image, label = dts[0]
    >>> image.shape
    torch.Size([3, 224, 224])
    >>> image.dtype
    torch.float32

The dataset can be configured to return slide names and tile locations by setting the ``incl_slidenames`` and ``incl_loc`` arguments to ``True``, as described above.

Dataset sharding is supported with the same ``rank`` and ``num_replicas`` arguments as described above.

.. code-block:: python

    # Shard for GPU 1 of 4
    dts = IndexedInterleaver(
        ...,
        rank=0,
        num_replicas=4
    )

:class:`slideflow.io.torch.IndexedInterleaver` supports undersampling via the ``clip`` argument (an array of clipping values, one for each TFRecord), but does not support oversampling or balancing.

.. code-block:: python

    # Specify TFRecord clipping values
    dts = IndexedInterleaver(
        tfrecords=...,
        clip=[100, 75, ...], # Same length as tfrecords
        ...
    )

A |DataLoader|_ can then be created from the indexable dataset using the ``torch.utils.data.DataLoader`` class, as described in the PyTorch documentation.

.. code-block:: python

    from torch.utils.data import DataLoader

    # Create a dataloader
    dl = DataLoader(
        dts,
        batch_size=32,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )

    for image, label in dl:
        # Do something with the image and label...
        ...