.. currentmodule:: slideflow.simclr

.. _simclr_ssl:

Self-Supervised Learning (SSL)
==============================

Slideflow provides easy access to training the self-supervised, contrastive learning framework `SimCLR <https://arxiv.org/abs/2002.05709>`_. Self-supervised learning provides an avenue for learning useful visual representations from your dataset without requiring ground-truth labels. These visual representations can be exported as feature vectors and used for downstream analyses such as :ref:`dimensionality reduction <slidemap>` or :ref:`multi-instance learning <mil>`.

The ``slideflow.simclr`` module contains a `forked TensorFlow implementation <https://github.com/jamesdolezal/simclr/>`_, minimally modified to interface with Slideflow. SimCLR models can be trained with :meth:`slideflow.Project.train_simclr`, and SimCLR features can be calculated, as with other models, using :meth:`slideflow.Project.generate_features`.

Training SimCLR
***************

First, determine the SimCLR training parameters with :func:`slideflow.simclr.get_args`. This function accepts parameters via keyword arguments, such as ``learning_rate`` and ``temperature``, and returns a configured :class:`slideflow.simclr.SimCLR_Args`.

.. code-block:: python

    from slideflow import simclr

    args = simclr.get_args(
        temperature=0.1,
        learning_rate=0.3,
        train_epochs=100,
        image_size=299
    )

Next, assemble a training and (optionally) a validation dataset. The validation dataset is used to assess contrastive loss during training, but is not required.

.. code-block:: python

    import slideflow as sf

    # Load a project and dataset
    P = sf.load_project('path')
    dataset = P.dataset(tile_px=299, tile_um=302)

    # Split dataset into training/validation
    train_dts, val_dts = dataset.split(
        val_fraction=0.3,
        model_type='classification',
        labels='subtype')

Finally, SimCLR can be trained with :meth:`slideflow.Project.train_simclr`. You can train with a single dataset:

.. code-block:: python

    P.train_simclr(args, dataset)

You can train with an optional validation dataset:

.. code-block:: python

    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts
    )

You can also provide labels to train a supervised head. To do so, you will also need to set the SimCLR argument ``lineareval_while_pretraining=True``.

.. code-block:: python

    # SimCLR args
    args = simclr.get_args(
        ...,
        lineareval_while_pretraining=True
    )

    # Train with validation & supervised head
    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts,
        outcomes='subtype'
    )

The SimCLR model checkpoints and final saved model will be saved in the ``simclr/`` folder within the project root directory.
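
For example, you can list the saved checkpoints to find one to use for feature extraction later. The snippet below is a minimal sketch using only the Python standard library; the exact subfolder layout and checkpoint naming under ``simclr/`` are assumptions and may vary between Slideflow versions.

.. code-block:: python

    import glob
    import os

    # Project root: the directory passed to sf.load_project() above.
    project_root = '/path/to/project'

    # List saved SimCLR checkpoints. The subfolder layout and checkpoint
    # naming are assumptions and may differ between Slideflow versions.
    ckpts = sorted(glob.glob(
        os.path.join(project_root, 'simclr', '**', '*.ckpt*'),
        recursive=True
    ))
    print(ckpts)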

.. _dinov2:

Training DINOv2
***************

A lightly modified version of `DINOv2 <https://arxiv.org/abs/2304.07193>`__ with Slideflow integration is available on `GitHub <https://github.com/jamesdolezal/dinov2>`_. This version facilitates training DINOv2 with Slideflow datasets and adds stain augmentation to the training pipeline.

To train DINOv2, first install the package:

.. code-block:: bash

    pip install git+https://github.com/jamesdolezal/dinov2.git

Next, configure the training parameters and datasets by providing a configuration YAML file. This configuration file should contain a ``slideflow`` section, which specifies the Slideflow project and dataset to use for training. An example YAML file is shown below:

.. code-block:: yaml

    train:
      dataset_path: slideflow
      batch_size_per_gpu: 32
      slideflow:
        project: "/mnt/data/projects/TCGA_THCA_BRAF"
        dataset:
          tile_px: 299
          tile_um: 302
          filters:
            brs_class:
            - "Braf-like"
            - "Ras-like"
        seed: 42
        outcome_labels: "brs_class"
        normalizer: "reinhard_mask"
        interleave_kwargs: null

See the `DINOv2 README <https://github.com/jamesdolezal/dinov2>`_ for more details on the configuration file format.

Finally, train DINOv2 using the same command-line interface as the original DINOv2 implementation. For example, to train DINOv2 on 4 GPUs on a single node:

.. code-block:: bash

    torchrun --nproc_per_node=4 -m "dinov2.train.train" \
        --config-file /path/to/config.yaml \
        --output-dir /path/to/output_dir

The teacher weights will be saved within the output directory at ``eval/.../teacher_checkpoint.pth``, and the final configuration YAML will be saved as ``config.yaml`` in the output directory root. Both files are used below when building a DINOv2 feature extractor.

Generating features
*******************

Generating features from a trained SSL model is straightforward: use the same :meth:`slideflow.Project.generate_features` and :class:`slideflow.DatasetFeatures` interfaces as :ref:`previously described <dataset_features>`, providing a path to a saved SimCLR model or checkpoint.

.. code-block:: python

    import slideflow as sf

    # Create the SimCLR feature extractor
    simclr = sf.build_feature_extractor(
        'simclr',
        ckpt='/path/to/simclr.ckpt'
    )

    # Calculate SimCLR features for a dataset
    features = P.generate_features(simclr, ...)
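
If you prefer to work with :class:`slideflow.DatasetFeatures` directly rather than through the project interface, a minimal sketch (assuming the ``simclr`` extractor and ``dataset`` from the examples above):

.. code-block:: python

    # Calculate SimCLR features with the DatasetFeatures interface.
    # Assumes `simclr` and `dataset` were created as shown above.
    dts_ftrs = sf.DatasetFeatures(simclr, dataset)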
143
144
For DINOv2 models, use ``'dinov2'`` as the first argument, and pass the model configuration YAML file to ``cfg`` and the teacher checkpoint weights to ``weights``.
145
146
.. code-block:: python
147
148
    dinov2 = build_feature_extractor(
149
        'dinov2',
150
        weights='/path/to/teacher_checkpoint.pth',
151
        cfg='/path/to/config.yaml'
152
    )
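
The resulting extractor can then be passed to :meth:`slideflow.Project.generate_features`, just like the SimCLR extractor above:

.. code-block:: python

    # Calculate DINOv2 features for a dataset
    features = P.generate_features(dinov2, ...)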