.. currentmodule:: slideflow.simclr

.. _simclr_ssl:

Self-Supervised Learning (SSL)
==============================

Slideflow provides easy access to training the self-supervised, contrastive learning framework `SimCLR <https://arxiv.org/abs/2002.05709>`_. Self-supervised learning provides an avenue for learning useful visual representations in your dataset without requiring ground-truth labels. These visual representations can be exported as feature vectors and used for downstream analyses such as :ref:`dimensionality reduction <slidemap>` or :ref:`multi-instance learning <mil>`.

The ``slideflow.simclr`` module contains a `forked TensorFlow implementation <https://github.com/jamesdolezal/simclr/>`_, minimally modified to interface with Slideflow. SimCLR models can be trained with :meth:`slideflow.Project.train_simclr`, and SimCLR features can be calculated as with other models using :meth:`slideflow.Project.generate_features`.

Training SimCLR
***************

First, determine the SimCLR training parameters with :func:`slideflow.simclr.get_args`. This function accepts parameters via keyword arguments, such as ``learning_rate`` and ``temperature``, and returns a configured :class:`slideflow.simclr.SimCLR_Args`.

.. code-block:: python

    from slideflow import simclr

    args = simclr.get_args(
        temperature=0.1,
        learning_rate=0.3,
        train_epochs=100,
        image_size=299
    )
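Conceptually, this keyword-argument pattern merges your overrides onto a set of defaults. The sketch below illustrates the defaults-plus-overrides idea only; the default values and error behavior shown here are hypothetical, not Slideflow's actual implementation.

```python
# Illustrative sketch of a defaults-plus-overrides helper.
# The default values below are hypothetical, not Slideflow's actual defaults.
DEFAULTS = {
    'temperature': 0.1,
    'learning_rate': 0.075,
    'train_epochs': 100,
    'image_size': 224,
}

def get_args_sketch(**kwargs):
    """Return a dict of hyperparameters, with kwargs overriding defaults."""
    unknown = set(kwargs) - set(DEFAULTS)
    if unknown:
        # Reject unrecognized keywords rather than silently ignoring them.
        raise ValueError(f"Unrecognized argument(s): {sorted(unknown)}")
    return {**DEFAULTS, **kwargs}

args = get_args_sketch(temperature=0.1, learning_rate=0.3,
                       train_epochs=100, image_size=299)
```

Any parameter not explicitly overridden keeps its default value.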

Next, assemble a training and (optionally) a validation dataset. The validation dataset is used to assess contrastive loss during training, but is not required.

.. code-block:: python

    import slideflow as sf

    # Load a project and dataset
    P = sf.load_project('path')
    dataset = P.dataset(tile_px=299, tile_um=302)

    # Split dataset into training/validation
    train_dts, val_dts = dataset.split(
        val_fraction=0.3,
        model_type='classification',
        labels='subtype')
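Here, ``val_fraction=0.3`` holds out roughly 30% of the data for validation. Slideflow's actual split is label-aware and operates on slides; the toy sketch below illustrates only the basic fraction-based idea on a plain list.

```python
import random

def fraction_split(items, val_fraction=0.3, seed=42):
    """Toy fraction-based split. Slideflow's Dataset.split is label-aware
    and operates on slides, so this is only an illustration."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_val = round(len(shuffled) * val_fraction)
    # Hold out the first n_val shuffled items for validation.
    return shuffled[n_val:], shuffled[:n_val]

train, val = fraction_split(range(10), val_fraction=0.3)
```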

Finally, SimCLR can be trained with :meth:`slideflow.Project.train_simclr`. You can train with a single dataset:

.. code-block:: python

    P.train_simclr(args, dataset)

You can also train with an optional validation dataset:

.. code-block:: python

    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts
    )

You can also optionally provide labels for training a supervised head. To do so, you'll need to set the SimCLR argument ``lineareval_while_pretraining=True``.

.. code-block:: python

    # SimCLR args
    args = simclr.get_args(
        ...,
        lineareval_while_pretraining=True
    )

    # Train with validation & supervised head
    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts,
        outcomes='subtype'
    )

The SimCLR model checkpoints and final saved model will be saved in the ``simclr/`` folder within the project root directory.
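Since checkpoints accumulate over training, it can be handy to locate the most recent one programmatically. A small stdlib sketch; the ``.ckpt`` suffix and flat-or-nested folder layout are assumptions, so check your project's ``simclr/`` directory for the actual file names.

```python
from pathlib import Path

def latest_checkpoint(simclr_dir):
    """Return the most recently modified .ckpt file under simclr_dir, or None."""
    ckpts = sorted(Path(simclr_dir).rglob('*.ckpt'),
                   key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```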

.. _dinov2:

Training DINOv2
***************

A lightly modified version of `DINOv2 <https://arxiv.org/abs/2304.07193>`__ with Slideflow integration is available on `GitHub <https://github.com/jamesdolezal/dinov2>`_. This version facilitates training DINOv2 with Slideflow datasets and adds stain augmentation to the training pipeline.

To train DINOv2, first install the package:

.. code-block:: bash

    pip install git+https://github.com/jamesdolezal/dinov2.git

Next, configure the training parameters and datasets by providing a configuration YAML file. This configuration file should contain a ``slideflow`` section, which specifies the Slideflow project and dataset to use for training. An example YAML file is shown below:

.. code-block:: yaml

    train:
      dataset_path: slideflow
      batch_size_per_gpu: 32
    slideflow:
      project: "/mnt/data/projects/TCGA_THCA_BRAF"
      dataset:
        tile_px: 299
        tile_um: 302
        filters:
          brs_class:
            - "Braf-like"
            - "Ras-like"
      seed: 42
      outcome_labels: "brs_class"
      normalizer: "reinhard_mask"
      interleave_kwargs: null

See the `DINOv2 README <https://github.com/jamesdolezal/dinov2>`_ for more details on the configuration file format.

Finally, train DINOv2 using the same command-line interface as the original DINOv2 implementation. For example, to train DINOv2 on 4 GPUs on a single node:

.. code-block:: bash

    torchrun --nproc_per_node=4 -m "dinov2.train.train" \
        --config-file /path/to/config.yaml \
        --output-dir /path/to/output_dir

The teacher weights will be saved in ``outdir/eval/.../teacher_checkpoint.pth``, and the final configuration YAML will be saved in ``outdir/config.yaml``.

Generating features
*******************

Generating features from a trained SSL model is straightforward: use the same :meth:`slideflow.Project.generate_features` and :class:`slideflow.DatasetFeatures` interfaces as :ref:`previously described <dataset_features>`, providing a path to a saved SimCLR model or checkpoint.

.. code-block:: python

    import slideflow as sf

    # Create the SimCLR feature extractor
    simclr = sf.build_feature_extractor(
        'simclr',
        ckpt='/path/to/simclr.ckpt'
    )

    # Calculate SimCLR features for a dataset
    features = P.generate_features(simclr, ...)

For DINOv2 models, use ``'dinov2'`` as the first argument, and pass the model configuration YAML file to ``cfg`` and the teacher checkpoint weights to ``weights``.

.. code-block:: python

    dinov2 = sf.build_feature_extractor(
        'dinov2',
        weights='/path/to/teacher_checkpoint.pth',
        cfg='/path/to/config.yaml'
    )
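Once feature vectors are exported, they can feed downstream analyses directly. As a minimal, dependency-free illustration of comparing two feature vectors, cosine similarity can be computed as below; in practice you would typically use NumPy or scikit-learn for this.

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

# Vectors pointing the same direction score 1.0; orthogonal vectors score 0.0.
```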