<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Tutorial 3: Using a custom architecture — slideflow 3.0.0 documentation</title>
<!-- Stylesheets: Pygments syntax highlighting first, then the theme, so theme
     rules win over Pygments rules of equal specificity. theme.css is linked
     exactly once (it was previously linked twice). -->
<link rel="stylesheet" href="../_static/pygments.css">
<link rel="stylesheet" href="../_static/css/theme.css">
<link rel="index" title="Index" href="../genindex/" />
<link rel="search" title="Search" href="../search/" />
<link rel="next" title="Tutorial 4: Model evaluation & heatmaps" href="../tutorial4/" />
<link rel="prev" title="Tutorial 2: Model training (advanced)" href="../tutorial2/" />
<script src="../_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
<!-- Privacy-friendly analytics; deferred so it does not block parsing. -->
<script defer data-domain="slideflow.dev" src="https://plausible.io/js/script.js"></script>
</head>
<body class="pytorch-body">
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://slideflow.dev" aria-label="Slideflow"></a>
<div class="main-menu">
<ul>
<li class="active">
<a href="https://slideflow.dev">Docs</a>
</li>
<li>
<a href="https://slideflow.dev/tutorial1/">Tutorials</a>
</li>
<li>
<a href="https://github.com/slideflow/slideflow">GitHub</a>
</li>
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu" aria-label="Open menu"></a>
</div>
</div>
</div>
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents" aria-label="Toggle table of contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
3.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search/" method="get">
<input type="text" name="q" placeholder="Search Docs" aria-label="Search Docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<p class="caption" role="heading"><span class="caption-text">Introduction</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../overview/">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../quickstart/">Quickstart</a></li>
<li class="toctree-l1"><a class="reference internal" href="../project_setup/">Setting up a Project</a></li>
<li class="toctree-l1"><a class="reference internal" href="../datasets_and_val/">Datasets</a></li>
<li class="toctree-l1"><a class="reference internal" href="../slide_processing/">Slide Processing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../training/">Training</a></li>
<li class="toctree-l1"><a class="reference internal" href="../evaluation/">Evaluation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../posthoc/">Layer Activations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../uq/">Uncertainty Quantification</a></li>
<li class="toctree-l1"><a class="reference internal" href="../features/">Generating Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../mil/">Multiple-Instance Learning (MIL)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../ssl/">Self-Supervised Learning (SSL)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../stylegan/">Generative Networks (GANs)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../saliency/">Saliency Maps</a></li>
<li class="toctree-l1"><a class="reference internal" href="../segmentation/">Tissue Segmentation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../cellseg/">Cell Segmentation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../custom_loops/">Custom Training Loops</a></li>
<li class="toctree-l1"><a class="reference internal" href="../studio/">Slideflow Studio: Live Visualization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../troubleshooting/">Troubleshooting</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Notes</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../tfrecords/">TFRecords: Reading and Writing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dataloaders/">Dataloaders: Sampling and Augmentation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../custom_extractors/">Custom Feature Extractors</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tile_labels/">Strong Supervision with Tile Labels</a></li>
<li class="toctree-l1"><a class="reference internal" href="../plugins/">Creating a Slideflow Plugin</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../slideflow/">slideflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../project/">slideflow.Project</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dataset/">slideflow.Dataset</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dataset_features/">slideflow.DatasetFeatures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../heatmap/">slideflow.Heatmap</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model_params/">slideflow.ModelParams</a></li>
<li class="toctree-l1"><a class="reference internal" href="../mosaic/">slideflow.Mosaic</a></li>
<li class="toctree-l1"><a class="reference internal" href="../slidemap/">slideflow.SlideMap</a></li>
<li class="toctree-l1"><a class="reference internal" href="../biscuit/">slideflow.biscuit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../slideflow_cellseg/">slideflow.cellseg</a></li>
<li class="toctree-l1"><a class="reference internal" href="../io/">slideflow.io</a></li>
<li class="toctree-l1"><a class="reference internal" href="../io_tensorflow/">slideflow.io.tensorflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../io_torch/">slideflow.io.torch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../gan/">slideflow.gan</a></li>
<li class="toctree-l1"><a class="reference internal" href="../grad/">slideflow.grad</a></li>
<li class="toctree-l1"><a class="reference internal" href="../mil_module/">slideflow.mil</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model/">slideflow.model</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model_tensorflow/">slideflow.model.tensorflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model_torch/">slideflow.model.torch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../norm/">slideflow.norm</a></li>
<li class="toctree-l1"><a class="reference internal" href="../simclr/">slideflow.simclr</a></li>
<li class="toctree-l1"><a class="reference internal" href="../slide/">slideflow.slide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../slide_qc/">slideflow.slide.qc</a></li>
<li class="toctree-l1"><a class="reference internal" href="../stats/">slideflow.stats</a></li>
<li class="toctree-l1"><a class="reference internal" href="../util/">slideflow.util</a></li>
<li class="toctree-l1"><a class="reference internal" href="../studio_module/">slideflow.studio</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../tutorial1/">Tutorial 1: Model training (simple)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial2/">Tutorial 2: Model training (advanced)</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorial 3: Using a custom architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial4/">Tutorial 4: Model evaluation & heatmaps</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial5/">Tutorial 5: Creating a mosaic map</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial6/">Tutorial 6: Custom slide filtering</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial7/">Tutorial 7: Training with custom augmentations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorial8/">Tutorial 8: Multiple-Instance Learning</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="../">
Docs
</a> >
</li>
<li>Tutorial 3: Using a custom architecture</li>
<li class="pytorch-breadcrumbs-aside">
<a href="../_sources/tutorial3.rst.txt" rel="nofollow"><img src="../_static/images/view-page-source-icon.svg" alt="View page source"></a>
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<section id="tutorial-3-using-a-custom-architecture">
<span id="tutorial3"></span><h1>Tutorial 3: Using a custom architecture<a class="headerlink" href="#tutorial-3-using-a-custom-architecture" title="Permalink to this heading">¶</a></h1>
<p>Out of the box, Slideflow includes support for 21 model architectures in the Tensorflow backend and 17 in the PyTorch backend. In this tutorial, we will demonstrate how to train a custom model architecture (ViT) in either backend.</p>
<section id="custom-tensorflow-model">
<h2>Custom Tensorflow model<a class="headerlink" href="#custom-tensorflow-model" title="Permalink to this heading">¶</a></h2>
<p>Any Tensorflow/Keras model (<code class="xref py py-class docutils literal notranslate"><span class="pre">tf.keras.Model</span></code>) can be trained in Slideflow by setting the <code class="docutils literal notranslate"><span class="pre">model</span></code> parameter of a <a class="reference internal" href="../model_params/#slideflow.ModelParams" title="slideflow.ModelParams"><code class="xref py py-class docutils literal notranslate"><span class="pre">slideflow.ModelParams</span></code></a> object to a function which initializes the model.</p>
<p>First, define the model in a file that can be imported. In this example, we will define a vision transformer (ViT) model in a file <code class="docutils literal notranslate"><span class="pre">vit_tensorflow.py</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># From:</span>
<span class="c1"># https://github.com/ashishpatel26/Vision-Transformer-Keras-Tensorflow-Pytorch-Examples/blob/main/Vision_Transformer_with_tf2.ipynb</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">six</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">from</span> <span class="nn">einops.layers.tensorflow</span> <span class="kn">import</span> <span class="n">Rearrange</span>
<span class="k">def</span> <span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""Gaussian Error Linear Unit.</span>
<span class="sd"> This is a smoother version of the RELU.</span>
<span class="sd"> Original paper: https://arxiv.org/abs/1606.08415</span>
<span class="sd"> Args:</span>
<span class="sd"> x: float Tensor to perform activation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> `x` with the GELU activation applied.</span>
<span class="sd"> """</span>
<span class="n">cdf</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span>
<span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))))</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">cdf</span>
<span class="k">def</span> <span class="nf">get_activation</span><span class="p">(</span><span class="n">identifier</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.</span>
<span class="sd"> It checks string first and if it is one of customized activation not in TF,</span>
<span class="sd"> the corresponding activation will be returned. For non-customized activation</span>
<span class="sd"> names and callable identifiers, always fallback to tf.keras.activations.get.</span>
<span class="sd"> Args:</span>
<span class="sd"> identifier: String name of the activation function or callable.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A Python function corresponding to the activation function.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">identifier</span><span class="p">,</span> <span class="n">six</span><span class="o">.</span><span class="n">string_types</span><span class="p">):</span>
<span class="n">name_to_fn</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"gelu"</span><span class="p">:</span> <span class="n">gelu</span><span class="p">}</span>
<span class="n">identifier</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">identifier</span><span class="p">)</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
<span class="k">if</span> <span class="n">identifier</span> <span class="ow">in</span> <span class="n">name_to_fn</span><span class="p">:</span>
<span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name_to_fn</span><span class="p">[</span><span class="n">identifier</span><span class="p">])</span>
<span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">identifier</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">Residual</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fn</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fn</span> <span class="o">=</span> <span class="n">fn</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">x</span>
<span class="k">class</span> <span class="nc">PreNorm</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">fn</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fn</span> <span class="o">=</span> <span class="n">fn</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
<span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span>
<span class="n">hidden_dim</span><span class="p">,</span>
<span class="n">activation</span><span class="o">=</span><span class="n">get_activation</span><span class="p">(</span><span class="s1">'gelu'</span><span class="p">)</span>
<span class="p">),</span>
<span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span><span class="p">)]</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">Attention</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">heads</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>
<span class="bp">self</span><span class="o">.</span><span class="n">to_qkv</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">to_out</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rearrange_qkv</span> <span class="o">=</span> <span class="n">Rearrange</span><span class="p">(</span>
<span class="s1">'b n (qkv h d) -> qkv b h n d'</span><span class="p">,</span>
<span class="n">qkv</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
<span class="n">h</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">heads</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rearrange_out</span> <span class="o">=</span> <span class="n">Rearrange</span><span class="p">(</span><span class="s1">'b h n d -> b n (h d)'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rearrange_qkv</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">dots</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhid,bhjd->bhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span>
<span class="n">attn</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dots</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhij,bhjd->bhid'</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rearrange_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="k">return</span> <span class="n">out</span>
<span class="k">class</span> <span class="nc">Transformer</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">):</span>
<span class="n">layers</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span>
<span class="n">Residual</span><span class="p">(</span><span class="n">PreNorm</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">Attention</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">heads</span><span class="o">=</span><span class="n">heads</span><span class="p">))),</span>
<span class="n">Residual</span><span class="p">(</span><span class="n">PreNorm</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">)))</span>
<span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">ViT</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">image_size</span> <span class="o">%</span> <span class="n">patch_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'image dimensions must be divisible by the '</span>
<span class="s1">'patch size'</span><span class="p">)</span>
<span class="n">num_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pos_embedding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
<span class="s2">"position_embeddings"</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">num_patches</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">],</span>
<span class="n">initializer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">RandomNormal</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">patch_to_embedding</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
<span class="s2">"cls_token"</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">],</span>
<span class="n">initializer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">RandomNormal</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rearrange</span> <span class="o">=</span> <span class="n">Rearrange</span><span class="p">(</span>
<span class="s1">'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'</span><span class="p">,</span>
<span class="n">p1</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span>
<span class="n">p2</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">transformer</span> <span class="o">=</span> <span class="n">Transformer</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">to_cls_token</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">identity</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mlp_head</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
<span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">mlp_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">get_activation</span><span class="p">(</span><span class="s1">'gelu'</span><span class="p">)),</span>
<span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)</span>
<span class="p">])</span>
<span class="nd">@tf</span><span class="o">.</span><span class="n">function</span>
<span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
<span class="n">shapes</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rearrange</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_to_embedding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">cls_tokens</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span><span class="p">,</span> <span class="p">(</span><span class="n">shapes</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">))</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">cls_tokens</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_embedding</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_cls_token</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp_head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
<p>Next, define a function that accepts any combination of the keyword arguments <code class="docutils literal notranslate"><span class="pre">input_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">include_top</span></code>, <code class="docutils literal notranslate"><span class="pre">pooling</span></code>, and/or <code class="docutils literal notranslate"><span class="pre">weights</span></code> and returns an instantiated model.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">vit_tensorflow</span> <span class="kn">import</span> <span class="n">ViT</span>
<span class="k">def</span> <span class="nf">vit_model</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="k">return</span> <span class="n">ViT</span><span class="p">(</span>
<span class="n">image_size</span><span class="o">=</span><span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">patch_size</span><span class="o">=</span><span class="mi">23</span><span class="p">,</span>
<span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
<span class="n">depth</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span>
<span class="n">heads</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="n">mlp_dim</span><span class="o">=</span><span class="mi">2048</span>
<span class="p">)</span>
</pre></div>
</div>
<p>Then, create a <a class="reference internal" href="../model_params/#slideflow.ModelParams" title="slideflow.ModelParams"><code class="xref py py-class docutils literal notranslate"><span class="pre">slideflow.ModelParams</span></code></a> object with your training parameters, setting the <code class="docutils literal notranslate"><span class="pre">model</span></code> argument equal to the function you just defined:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">slideflow</span> <span class="k">as</span> <span class="nn">sf</span>
<span class="kn">from</span> <span class="nn">vit_tensorflow</span> <span class="kn">import</span> <span class="n">ViT</span>
<span class="k">def</span> <span class="nf">vit_model</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="o">...</span>
<span class="n">hp</span> <span class="o">=</span> <span class="n">sf</span><span class="o">.</span><span class="n">ModelParams</span><span class="p">(</span>
<span class="n">tile_px</span><span class="o">=</span><span class="mi">299</span><span class="p">,</span>
<span class="n">tile_um</span><span class="o">=</span><span class="mi">302</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">model</span><span class="o">=</span><span class="n">vit_model</span><span class="p">,</span>
<span class="o">...</span>
<span class="p">)</span>
</pre></div>
</div>
<p>You can now train the model as described in <a class="reference internal" href="../tutorial1/#tutorial1"><span class="std std-ref">Tutorial 1: Model training (simple)</span></a>.</p>
</section>
<section id="custom-pytorch-model">
<h2>Custom PyTorch model<a class="headerlink" href="#custom-pytorch-model" title="Permalink to this heading">¶</a></h2>
<p>The process is very similar when using PyTorch. In this example, instead of defining the architecture in a separate file, we will use an implementation of ViT available via PyPI:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">pip3</span> <span class="n">install</span> <span class="n">vit</span><span class="o">-</span><span class="n">pytorch</span>
</pre></div>
</div>
<p>Next, define a function that accepts any combination of the keyword arguments <code class="docutils literal notranslate"><span class="pre">image_size</span></code> and/or <code class="docutils literal notranslate"><span class="pre">pretrained</span></code> and returns an instantiated model.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">slideflow</span> <span class="k">as</span> <span class="nn">sf</span>
<span class="kn">from</span> <span class="nn">vit_pytorch</span> <span class="kn">import</span> <span class="n">ViT</span>
<span class="k">def</span> <span class="nf">vit_model</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">ViT</span><span class="p">(</span>
<span class="n">image_size</span><span class="o">=</span><span class="n">image_size</span><span class="p">,</span>
<span class="n">patch_size</span><span class="o">=</span><span class="mi">23</span><span class="p">,</span>
<span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
<span class="n">depth</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span>
<span class="n">heads</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="n">mlp_dim</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span>
<span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
<span class="n">emb_dropout</span><span class="o">=</span><span class="mf">0.1</span>
<span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="k">return</span> <span class="n">model</span>
</pre></div>
</div>
<p>Finally, set the <code class="docutils literal notranslate"><span class="pre">model</span></code> argument of a <a class="reference internal" href="../model_params/#slideflow.ModelParams" title="slideflow.ModelParams"><code class="xref py py-class docutils literal notranslate"><span class="pre">slideflow.ModelParams</span></code></a> object equal to this function:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">slideflow</span> <span class="k">as</span> <span class="nn">sf</span>
<span class="kn">from</span> <span class="nn">vit_pytorch</span> <span class="kn">import</span> <span class="n">ViT</span>
<span class="k">def</span> <span class="nf">vit_model</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="o">...</span>
<span class="n">hp</span> <span class="o">=</span> <span class="n">sf</span><span class="o">.</span><span class="n">ModelParams</span><span class="p">(</span>
<span class="n">tile_px</span><span class="o">=</span><span class="mi">299</span><span class="p">,</span>
<span class="n">tile_um</span><span class="o">=</span><span class="mi">302</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">model</span><span class="o">=</span><span class="n">vit_model</span><span class="p">,</span>
<span class="o">...</span>
<span class="p">)</span>
</pre></div>
</div>
<p>You can now train the model as described in <a class="reference internal" href="../tutorial1/#tutorial1"><span class="std std-ref">Tutorial 1: Model training (simple)</span></a>.</p>
</section>
</section>
</article>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../tutorial4/" class="btn btn-neutral float-right" title="Tutorial 4: Model evaluation & heatmaps" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page" alt=""></a>
<a href="../tutorial2/" class="btn btn-neutral" title="Tutorial 2: Model training (advanced)" accesskey="p" rel="prev"><img src="../_static/images/chevron-right-orange.svg" class="previous-page" alt=""> Previous</a>
</div>
<hr>
<div role="contentinfo">
<p>
© Copyright 2023, James M Dolezal.
</p>
</div>
<div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</div>
</footer>
</div>
</div>
<div class="pytorch-content-right" id="pytorch-content-right">
<div class="pytorch-right-menu" id="pytorch-right-menu">
<div class="pytorch-side-scroll" id="pytorch-side-scroll-right">
<ul>
<li><a class="reference internal" href="#">Tutorial 3: Using a custom architecture</a><ul>
<li><a class="reference internal" href="#custom-tensorflow-model">Custom Tensorflow model</a></li>
<li><a class="reference internal" href="#custom-pytorch-model">Custom PyTorch model</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</section>
</div>
<script id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script type="text/javascript" src="../_static/js/vendor/jquery-3.6.3.min.js"></script>
<script type="text/javascript" src="../_static/js/vendor/popper.min.js"></script>
<script type="text/javascript" src="../_static/js/vendor/bootstrap.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/list.js/1.5.0/list.min.js"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
<!-- Begin Footer -->
<!-- End Footer -->
<!-- Begin Mobile Menu -->
<div class="mobile-main-menu">
<div class="container-fluid">
<div class="container">
<div class="mobile-main-menu-header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<a class="main-menu-close-button" href="#" data-behavior="close-mobile-menu" aria-label="Close menu"></a>
</div>
</div>
</div>
<div class="mobile-main-menu-links-container">
<div class="main-menu">
<ul>
<li>
<a href="https://slideflow.dev">Docs</a>
</li>
<li>
<a href="https://slideflow.dev/tutorial1/">Tutorials</a>
</li>
<li>
<a href="https://github.com/slideflow/slideflow">Github</a>
</li>
</ul>
</div>
</div>
</div>
<!-- End Mobile Menu -->
<script type="text/javascript">
var collapsedSections = [];
</script>
<script type="text/javascript" src="../_static/js/vendor/anchor.min.js"></script>
<script type="text/javascript">
$(document).ready(function() {
mobileMenu.bind();
mobileTOC.bind();
pytorchAnchors.bind();
sideMenus.bind();
scrollToAnchor.bind();
highlightNavigation.bind();
mainMenuDropdown.bind();
filterTags.bind();
// Add class to links that have code blocks, since we cannot create links in code blocks
$("article.pytorch-article a span.pre").each(function(e) {
$(this).closest("a").addClass("has-code");
});
})
</script>
</body>
</html>