Custom Training Loops
=====================

To use ``*.tfrecords`` from extracted tiles in a custom training loop or an entirely separate architecture (such as `StyleGAN2 <https://github.com/jamesdolezal/stylegan2-slideflow>`_ or `YoloV5 <https://github.com/ultralytics/yolov5>`_), create TensorFlow ``tf.data.Dataset`` or PyTorch ``torch.utils.data.DataLoader`` objects that serve processed images directly to your custom trainer.

TFRecord DataLoader
*******************

The :class:`slideflow.Dataset` class provides methods that prepare a TensorFlow ``tf.data.Dataset`` or PyTorch ``torch.utils.data.DataLoader`` to interleave and process images from stored TFRecords. First, create a ``Dataset`` object at a given tile size:

.. code-block:: python

    from slideflow import Project

    P = Project('/project/path', ...)
    dts = P.dataset(tile_px=299, tile_um=302)  # Tiles 299 pixels wide, covering 302 microns
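
``tile_px`` sets the width of each tile in pixels and ``tile_um`` its width in microns, so together they fix the effective resolution of the images served to your trainer. The arithmetic for the values above is sketched below; ``tile_geometry`` is a hypothetical helper (not part of slideflow), and the native slide resolution of 0.25 µm/px is an assumed example value, typical of a 40x scan.

.. code-block:: python

    def tile_geometry(tile_px, tile_um, native_mpp):
        """Hypothetical helper: return (output microns per pixel, tile width in native pixels)."""
        output_mpp = tile_um / tile_px           # microns covered by each output pixel
        native_px = round(tile_um / native_mpp)  # slide pixels spanned by one tile before resizing
        return output_mpp, native_px


    # Assumed native resolution of 0.25 um/px (hypothetical slide).
    mpp, region = tile_geometry(tile_px=299, tile_um=302, native_mpp=0.25)
    print(f"{mpp:.3f} um/px, tile spans {region}x{region} slide pixels")
    # 1.010 um/px, tile spans 1208x1208 slide pixels

In other words, each 299-pixel tile covers a 302 µm square of tissue at roughly 1 µm per pixel.
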

To balance mini-batches, for example so that each ``HPV_status`` category is drawn equally often, use the ``.balance()`` method:

.. code-block:: python

    dts = dts.balance('HPV_status', strategy='category')
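
With ``strategy='category'``, balancing aims to draw each outcome category equally often, regardless of how many tiles it contains. The sketch below illustrates that idea with a hypothetical ``category_balanced_sampler`` and made-up tile counts; it is a conceptual model of category balancing, not slideflow's interleaving code.

.. code-block:: python

    import random
    from collections import Counter


    def category_balanced_sampler(tiles_by_category, seed=0):
        """Yield (category, tile) forever: pick a category uniformly, then a tile within it."""
        rng = random.Random(seed)
        categories = sorted(tiles_by_category)
        while True:
            category = rng.choice(categories)  # every category is equally likely
            yield category, rng.choice(tiles_by_category[category])


    # Hypothetical, imbalanced data: 900 HPV-negative tiles vs. 100 HPV-positive tiles.
    tiles = {
        "negative": [f"neg_{i}" for i in range(900)],
        "positive": [f"pos_{i}" for i in range(100)],
    }
    sampler = category_balanced_sampler(tiles)
    counts = Counter(category for category, _ in (next(sampler) for _ in range(10_000)))
    print(counts)  # roughly 5,000 draws per category despite the 9:1 imbalance
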

Other dataset options can also be applied at this step. For example, to cap the number of tiles taken from each slide, use the ``.clip()`` method:

.. code-block:: python

    dts = dts.clip(500)  # Use at most 500 tiles per slide
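
Clipping caps how many tiles each slide can contribute; slides that already have fewer tiles than the cap are unaffected. A minimal sketch of that per-slide effect, using a hypothetical ``clipped_counts`` helper and made-up slide names and tile counts:

.. code-block:: python

    def clipped_counts(tiles_per_slide, max_tiles):
        """Hypothetical helper: tiles each slide contributes after clipping."""
        return {slide: min(n, max_tiles) for slide, n in tiles_per_slide.items()}


    # Hypothetical tile counts per slide.
    counts = {"slide_A": 2400, "slide_B": 310, "slide_C": 780}
    print(clipped_counts(counts, max_tiles=500))
    # {'slide_A': 500, 'slide_B': 310, 'slide_C': 500}
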

Finally, use the :meth:`slideflow.Dataset.torch` method to create a PyTorch ``DataLoader``:

.. code-block:: python

    dataloader = dts.torch(
        labels       = ...,      # Your outcome label
        batch_size   = 64,       # Batch size
        num_workers  = 6,        # Number of workers reading tfrecords
        infinite     = True,     # True for training, False for validation
        augment      = True,     # Flip/rotate/compression augmentation
        standardize  = True,     # Standardize images: mean 0, variance 1
        pin_memory   = False,    # Pin (page-lock) host memory for faster GPU transfer
    )
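
Per-image standardization, as implemented by TensorFlow's ``tf.image.per_image_standardization``, rescales each image to zero mean and unit variance: it subtracts the image mean and divides by the standard deviation, floored at ``1/sqrt(N)`` for an image with ``N`` values so that a uniform image does not cause division by zero. The pure-Python sketch below applies that formula to a hypothetical list of pixel values; treat it as an illustration of what ``standardize=True`` aims for, not as slideflow's exact implementation.

.. code-block:: python

    import math


    def standardize(pixels):
        """Per-image standardization: (x - mean) / max(std, 1/sqrt(N))."""
        n = len(pixels)
        mean = sum(pixels) / n
        std = math.sqrt(sum((p - mean) ** 2 for p in pixels) / n)  # population std
        adjusted_std = max(std, 1.0 / math.sqrt(n))  # guards against uniform images
        return [(p - mean) / adjusted_std for p in pixels]


    out = standardize([0, 64, 128, 255])  # hypothetical flattened pixel values
    mean = sum(out) / len(out)
    var = sum(v * v for v in out) / len(out) - mean ** 2
    print(f"{abs(mean):.6f} {var:.6f}")  # 0.000000 1.000000
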

Alternatively, use the :meth:`slideflow.Dataset.tensorflow` method to create a TensorFlow ``tf.data.Dataset``:

.. code-block:: python

    dataloader = dts.tensorflow(
        labels       = ...,      # Your outcome label
        batch_size   = 64,       # Batch size
        infinite     = True,     # True for training, False for validation
        augment      = True,     # Flip/rotate/compression augmentation
        standardize  = True,     # Standardize images: mean 0, variance 1
    )
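
The flip and rotation part of ``augment=True`` maps each tile to one of its eight mirror/rotation orientations; histology tiles generally have no canonical orientation, so each of these is an equally valid view. The sketch below shows those two operations on a tiny hypothetical 2x2 grid. The flip probability and uniform choice of rotation are assumptions, and the compression (JPEG) component is not shown.

.. code-block:: python

    import random


    def flip_horizontal(img):
        """Mirror a 2D grid left to right."""
        return [row[::-1] for row in img]


    def rotate_90(img):
        """Rotate a 2D grid 90 degrees clockwise."""
        return [list(row) for row in zip(*img[::-1])]


    def random_flip_rotate(img, rng):
        """Apply a random horizontal flip, then a random multiple of 90-degree rotation."""
        if rng.random() < 0.5:
            img = flip_horizontal(img)
        for _ in range(rng.randrange(4)):
            img = rotate_90(img)
        return img


    tile = [[1, 2],
            [3, 4]]  # hypothetical 2x2 "image"
    print(rotate_90(tile))  # [[3, 1], [4, 2]]
    print(random_flip_rotate(tile, random.Random(0)))
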

The returned dataloaders can then be used directly with your external applications. Read more about :ref:`creating and using dataloaders <dataloaders>`.
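
With ``infinite=True`` the loader never runs out of batches, so a custom loop has to decide when an epoch (or the whole run) ends, for example by counting steps. The sketch below bounds each epoch with ``itertools.islice``. It uses a hypothetical ``fake_infinite_loader`` in place of the real dataloader, assumes each batch is an ``(images, labels)`` pair, and ``train_step`` is a hypothetical placeholder for your model update.

.. code-block:: python

    import itertools


    def fake_infinite_loader(batch_size=4):
        """Hypothetical stand-in for an infinite dataloader yielding (images, labels)."""
        step = 0
        while True:
            images = [f"img_{step}_{i}" for i in range(batch_size)]
            labels = [i % 2 for i in range(batch_size)]
            yield images, labels
            step += 1


    def train_step(images, labels):
        """Placeholder for a forward/backward pass; returns a dummy loss."""
        return float(sum(labels)) / len(labels)


    def train(dataloader, steps_per_epoch, epochs):
        losses = []
        batches = iter(dataloader)  # one iterator shared across all epochs
        for _ in range(epochs):
            # Take exactly steps_per_epoch batches from the never-ending stream.
            for images, labels in itertools.islice(batches, steps_per_epoch):
                losses.append(train_step(images, labels))
        return losses


    losses = train(fake_infinite_loader(), steps_per_epoch=10, epochs=3)
    print(len(losses))  # 30

Creating the iterator once and reusing it across epochs keeps the stream continuous; with a PyTorch ``DataLoader`` that uses worker processes, each new ``iter()`` call typically starts a fresh set of workers.
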