.. currentmodule:: slideflow.simclr

.. _simclr_ssl:

Self-Supervised Learning (SSL)
==============================

Slideflow provides easy access to training the self-supervised, contrastive learning framework `SimCLR <https://arxiv.org/abs/2002.05709>`_. Self-supervised learning provides an avenue for learning useful visual representations in your dataset without requiring ground-truth labels. These visual representations can be exported as feature vectors and used for downstream analyses such as :ref:`dimensionality reduction <slidemap>` or :ref:`multi-instance learning <mil>`.

The ``slideflow.simclr`` module contains a `forked TensorFlow implementation <https://github.com/jamesdolezal/simclr/>`_ minimally modified to interface with Slideflow. SimCLR models can be trained with :meth:`slideflow.Project.train_simclr`, and SimCLR features can be calculated as with other models using :meth:`slideflow.Project.generate_features`.

Training SimCLR
***************

First, determine the SimCLR training parameters with :func:`slideflow.simclr.get_args`. This function accepts parameters via keyword arguments, such as ``learning_rate`` and ``temperature``, and returns a configured :class:`slideflow.simclr.SimCLR_Args`.

.. code-block:: python

    from slideflow import simclr

    args = simclr.get_args(
        temperature=0.1,
        learning_rate=0.3,
        train_epochs=100,
        image_size=299
    )

Next, assemble a training dataset and, optionally, a validation dataset. The validation dataset is used to assess contrastive loss during training, but is not required.

.. code-block:: python

    import slideflow as sf

    # Load a project and dataset
    P = sf.load_project('path')
    dataset = P.dataset(tile_px=299, tile_um=302)

    # Split dataset into training/validation
    train_dts, val_dts = dataset.split(
        val_fraction=0.3,
        model_type='classification',
        labels='subtype')

Finally, train SimCLR with :meth:`slideflow.Project.train_simclr`. You can train with a single dataset:

.. code-block:: python

    P.train_simclr(args, dataset)

You can also train with an optional validation dataset:

.. code-block:: python

    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts
    )

Finally, you can optionally provide labels for training a supervised head; this also requires setting the SimCLR argument ``lineareval_while_pretraining=True``.

.. code-block:: python

    # SimCLR args
    args = simclr.get_args(
        ...,
        lineareval_while_pretraining=True
    )

    # Train with validation & supervised head
    P.train_simclr(
        args,
        train_dataset=train_dts,
        val_dataset=val_dts,
        outcomes='subtype'
    )

The SimCLR model checkpoints and final saved model will be saved in the ``simclr/`` folder within the project root directory.
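Once training is complete, the saved model or a checkpoint can be loaded as a feature extractor for downstream analyses, as detailed under *Generating features* below. As a minimal sketch, assuming a checkpoint named ``ckpt-100.ckpt`` was written to the project's ``simclr/`` folder (the filename is illustrative):

.. code-block:: python

    import slideflow as sf

    # Build a feature extractor from a saved SimCLR checkpoint.
    # The checkpoint filename is illustrative; substitute a
    # checkpoint from your project's simclr/ folder.
    simclr_extractor = sf.build_feature_extractor(
        'simclr',
        ckpt='/path/to/project/simclr/ckpt-100.ckpt'
    )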
.. _dinov2:

Training DINOv2
***************

A lightly modified version of `DINOv2 <https://arxiv.org/abs/2304.07193>`__ with Slideflow integration is available on `GitHub <https://github.com/jamesdolezal/dinov2>`_. This version facilitates training DINOv2 with Slideflow datasets and adds stain augmentation to the training pipeline.

To train DINOv2, first install the package:

.. code-block:: bash

    pip install git+https://github.com/jamesdolezal/dinov2.git

Next, configure the training parameters and datasets with a YAML configuration file. This file should contain a ``slideflow`` section specifying the Slideflow project and dataset to use for training. An example YAML file is shown below:

.. code-block:: yaml

    train:
      dataset_path: slideflow
      batch_size_per_gpu: 32
    slideflow:
      project: "/mnt/data/projects/TCGA_THCA_BRAF"
      dataset:
        tile_px: 299
        tile_um: 302
        filters:
          brs_class:
          - "Braf-like"
          - "Ras-like"
      seed: 42
      outcome_labels: "brs_class"
      normalizer: "reinhard_mask"
      interleave_kwargs: null

See the `DINOv2 README <https://github.com/jamesdolezal/dinov2>`_ for more details on the configuration file format.

Finally, train DINOv2 using the same command-line interface as the original DINOv2 implementation. For example, to train DINOv2 on 4 GPUs on a single node:

.. code-block:: bash

    torchrun --nproc_per_node=4 -m "dinov2.train.train" \
        --config-file /path/to/config.yaml \
        --output-dir /path/to/output_dir

The teacher weights will be saved in ``outdir/eval/.../teacher_checkpoint.pth``, where ``outdir`` is the output directory specified above, and the final configuration YAML will be saved in ``outdir/config.yaml``.

Generating features
*******************

Generating features from a trained SSL model is straightforward: use the same :meth:`slideflow.Project.generate_features` and :class:`slideflow.DatasetFeatures` interfaces as :ref:`previously described <dataset_features>`, providing a path to a saved SimCLR model or checkpoint.

.. code-block:: python

    import slideflow as sf

    # Create the SimCLR feature extractor
    simclr = sf.build_feature_extractor(
        'simclr',
        ckpt='/path/to/simclr.ckpt'
    )

    # Calculate SimCLR features for a dataset
    features = P.generate_features(simclr, ...)

For DINOv2 models, use ``'dinov2'`` as the first argument, and pass the model configuration YAML file to ``cfg`` and the teacher checkpoint weights to ``weights``.

.. code-block:: python

    dinov2 = sf.build_feature_extractor(
        'dinov2',
        weights='/path/to/teacher_checkpoint.pth',
        cfg='/path/to/config.yaml'
    )
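The resulting extractor is then used in the same way as the SimCLR extractor above. For example:

.. code-block:: python

    # Calculate DINOv2 features for a dataset
    features = P.generate_features(dinov2, ...)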