Custom Training Loops
=====================

To use ``*.tfrecords`` from extracted tiles in a custom training loop or an entirely separate architecture (such as `StyleGAN2 <https://github.com/jamesdolezal/stylegan2-slideflow>`_ or `YOLOv5 <https://github.com/ultralytics/yolov5>`_), you can create a TensorFlow ``tf.data.Dataset`` or PyTorch ``torch.utils.data.DataLoader`` that serves processed images to your custom trainer.

TFRecord DataLoader
*******************

The :class:`slideflow.Dataset` class includes methods that prepare a TensorFlow ``tf.data.Dataset`` or PyTorch ``torch.utils.data.DataLoader`` to interleave and process images from stored TFRecords. First, create a ``Dataset`` object at a given tile size:

.. code-block:: python

    from slideflow import Project

    P = Project('/project/path', ...)
    dts = P.dataset(tile_px=299, tile_um=302)

If you want to perform mini-batch balancing, use the ``.balance()`` method:

.. code-block:: python

    dts = dts.balance('HPV_status', strategy='category')

Other dataset options can also be applied at this step. For example, to cap the number of tiles taken from each slide, use the ``.clip()`` method:

.. code-block:: python

    dts = dts.clip(500)

Finally, use the :meth:`slideflow.Dataset.torch` method to create a ``DataLoader`` object:

.. code-block:: python

    dataloader = dts.torch(
        labels=...,         # Your outcome label
        batch_size=64,      # Batch size
        num_workers=6,      # Number of workers reading tfrecords
        infinite=True,      # True for training, False for validation
        augment=True,       # Flip/rotate/compression augmentation
        standardize=True,   # Standardize images: mean 0, variance of 1
        pin_memory=False,   # Pin host memory for faster transfer to GPU
    )

Alternatively, use the :meth:`slideflow.Dataset.tensorflow` method to create a ``tf.data.Dataset``:

.. code-block:: python

    dataloader = dts.tensorflow(
        labels=...,         # Your outcome label
        batch_size=64,      # Batch size
        infinite=True,      # True for training, False for validation
        augment=True,       # Flip/rotate/compression augmentation
        standardize=True,   # Standardize images
    )

The returned dataloaders can then be used directly with your external applications. Read more about :ref:`creating and using dataloaders <dataloaders>`.
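A dataloader created with ``infinite=True`` never signals the end of an epoch, so a custom loop has to bound the number of steps itself. The sketch below illustrates one way to do that using only the standard library. It is not Slideflow code: ``fake_infinite_loader``, ``train``, and ``dummy_step`` are hypothetical names, and the ``(images, labels)`` batch structure is an assumption about what your dataloader yields.

```python
import itertools
import random


def fake_infinite_loader(batch_size=4, seed=0):
    """Hypothetical stand-in for a dataloader built with infinite=True.

    Yields (images, labels) batches forever. Real batches would be tensors;
    plain lists are used here so the sketch has no dependencies.
    """
    rng = random.Random(seed)
    while True:
        images = [[rng.random() for _ in range(3)] for _ in range(batch_size)]
        labels = [rng.randint(0, 1) for _ in range(batch_size)]
        yield images, labels


def train(loader, steps_per_epoch, epochs, train_step):
    """Run a step-bounded loop over a loader that may never stop yielding."""
    batches = iter(loader)  # one iterator, so each epoch continues the stream
    history = []
    for _ in range(epochs):
        # islice caps the number of batches drawn in this epoch.
        losses = [
            train_step(images, labels)
            for images, labels in itertools.islice(batches, steps_per_epoch)
        ]
        if not losses:  # a finite loader ran out of batches
            break
        history.append(sum(losses) / len(losses))
    return history


def dummy_step(images, labels):
    """Placeholder for forward/backward/optimizer logic; returns a fake loss."""
    return sum(sum(img) for img in images) / len(images)


history = train(fake_infinite_loader(), steps_per_epoch=5, epochs=3,
                train_step=dummy_step)
print(len(history))  # one mean loss per epoch
```

In a real trainer, ``dummy_step`` would be replaced by your own forward and backward pass, and ``fake_infinite_loader()`` by the dataloader returned from the ``Dataset`` object. For validation dataloaders created with ``infinite=False``, a plain ``for`` loop over the dataloader is sufficient.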