{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "WNXEyZ23rdz-"
      },
      "outputs": [],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "#from google.colab import drive\n",
        "#drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "sANL984ird0B"
      },
      "outputs": [],
      "source": [
        "#!pip install pennylane\n",
        "# Some parts of this code are based on the Python script:\n",
        "# https://github.com/pytorch/tutorials/blob/master/beginner_source/transfer_learning_tutorial.py\n",
        "# License: BSD\n",
        "\n",
        "import time\n",
        "import os\n",
        "import copy\n",
        "\n",
        "# PyTorch\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.optim import lr_scheduler\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Pennylane\n",
        "import pennylane as qml\n",
        "from pennylane import numpy as np\n",
        "\n",
        "torch.manual_seed(42)\n",
        "np.random.seed(42)\n",
        "\n",
        "# Plotting\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# OpenMP: number of parallel threads.\n",
        "os.environ[\"OMP_NUM_THREADS\"] = \"1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7eGTk_Prd0B"
      },
      "source": [
        "Setting of the main hyper-parameters of the model\n",
        "=================================================\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "To reproduce the results of Ref. \\[1\\], `num_epochs` should be set to\n",
        "`30`, which may take a long time. We suggest first trying\n",
        "`num_epochs=1` and, if everything runs smoothly, increasing it to a larger\n",
        "value.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "Wncy61Mdrd0B"
      },
      "outputs": [],
      "source": [
        "n_qubits = 8                # Number of qubits\n",
        "step = 0.0006               # Learning rate\n",
        "batch_size = 6              # Number of samples for each training step\n",
        "num_epochs = 1              # Number of training epochs\n",
        "q_depth = 6                 # Depth of the quantum circuit (number of variational layers)\n",
        "gamma_lr_scheduler = 0.1    # Learning rate reduction applied every 10 epochs.\n",
        "q_delta = 0.01              # Initial spread of random quantum weights\n",
        "start_time = time.time()    # Start of the computation timer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVqjLo8Rrd0B"
      },
      "source": [
        "We initialize a PennyLane device with a `default.qubit` backend.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "qLOa5trRrd0B"
      },
      "outputs": [],
      "source": [
        "dev = qml.device(\"default.qubit\", wires=n_qubits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dxlK7Jjtrd0C"
      },
      "source": [
        "We configure PyTorch to use CUDA only if available. Otherwise the CPU is\n",
        "used.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "Br_YGwRDrd0C"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WFHuYd5xrd0C"
      },
      "source": [
        "Dataset loading\n",
        "===============\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The original tutorial uses a dataset of *ants* and *bees* images, which can be downloaded\n",
        "[here](https://download.pytorch.org/tutorial/hymenoptera_data.zip) and\n",
        "extracted into the subfolder `../_data/hymenoptera_data`. This notebook\n",
        "instead points `data_dir` at a Google Drive folder named\n",
        "`4_cl_2xbraintumor_data`.\n",
        ":::\n",
        "\n",
        "The ants-and-bees dataset is very small (roughly 250 images): too small to\n",
        "train a classical or quantum model from scratch, but sufficient when\n",
        "using a *transfer learning* approach.\n",
        "\n",
        "The PyTorch packages `torchvision` and `torch.utils.data` are used for\n",
        "loading the dataset and performing standard preliminary image\n",
        "operations: resize, center crop, normalize, *etc.*\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#os.system(\"rm -rf /content/data/.ipynb_checkpoints\")\n",
        "#os.system(\"rm -rf /content/data/braintumor_data/.ipynb_checkpoints\")\n",
        "#os.system(\"rm -rf /content/data/braintumor_data/train/.ipynb_checkpoints\")\n",
        "#os.system(\"rm -rf /content/data/braintumor_data/val/.ipynb_checkpoints\")"
      ],
      "metadata": {
        "id": "DM8SDO3Wthcc"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "jQrNNVnUrd0C"
      },
      "outputs": [],
      "source": [
        "data_transforms = {\n",
        "    \"train\": transforms.Compose(\n",
        "        [\n",
        "            # transforms.RandomResizedCrop(224),     # uncomment for data augmentation\n",
        "            # transforms.RandomHorizontalFlip(),     # uncomment for data augmentation\n",
        "            transforms.Resize(256),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            # Normalize input channels using mean values and standard deviations of ImageNet.\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "    \"val\": transforms.Compose(\n",
        "        [\n",
        "            transforms.Resize(256),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "}\n",
        "\n",
        "data_dir = \"/content/drive/MyDrive/Colab Notebooks/data/4_cl_2xbraintumor_data\"\n",
        "image_datasets = {\n",
        "    x if x == \"train\" else \"validation\": datasets.ImageFolder(\n",
        "        os.path.join(data_dir, x), data_transforms[x]\n",
        "    )\n",
        "    for x in [\"train\", \"val\"]\n",
        "}\n",
        "dataset_sizes = {x: len(image_datasets[x]) for x in [\"train\", \"validation\"]}\n",
        "class_names = image_datasets[\"train\"].classes\n",
        "\n",
        "# Initialize dataloader\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}\n",
        "\n",
        "# function to plot images\n",
        "def imshow(inp, title=None):\n",
        "    \"\"\"Display image from tensor.\"\"\"\n",
        "    inp = inp.numpy().transpose((1, 2, 0))\n",
        "    # Inverse of the initial normalization operation.\n",
        "    mean = np.array([0.485, 0.456, 0.406])\n",
        "    std = np.array([0.229, 0.224, 0.225])\n",
        "    inp = std * inp + mean\n",
        "    inp = np.clip(inp, 0, 1)\n",
        "    plt.imshow(inp)\n",
        "    if title is not None:\n",
        "        plt.title(title)"
      ]
    },
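    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `imshow` helper above undoes `transforms.Normalize` before plotting. A minimal\n",
        "sketch of that round trip in plain Python, for a single RGB pixel, is given below.\n",
        "The helper names `normalize` and `denormalize` and the sample pixel are illustrative\n",
        "assumptions; they are not part of the notebook or of `torchvision`.\n",
        "\n",
        "```python\n",
        "# ImageNet channel statistics, as used in data_transforms and imshow.\n",
        "MEAN = [0.485, 0.456, 0.406]\n",
        "STD = [0.229, 0.224, 0.225]\n",
        "\n",
        "\n",
        "def normalize(pixel):\n",
        "    # Per-channel (x - mean) / std, the operation Normalize applies.\n",
        "    return [(p - m) / s for p, m, s in zip(pixel, MEAN, STD)]\n",
        "\n",
        "\n",
        "def denormalize(pixel):\n",
        "    # Per-channel std * x + mean, the inverse used inside imshow.\n",
        "    return [s * p + m for p, m, s in zip(pixel, MEAN, STD)]\n",
        "\n",
        "\n",
        "pixel = [0.2, 0.5, 0.9]  # hypothetical RGB values in [0, 1]\n",
        "restored = denormalize(normalize(pixel))\n",
        "max_error = max(abs(a - b) for a, b in zip(pixel, restored))\n",
        "print(max_error)\n",
        "```\n",
        "\n",
        "The round-trip error should stay at floating-point precision."
      ]
    },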
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "piWk71nkrd0C"
      },
      "source": [
        "Let us show a batch of the validation data, just to get an idea of the\n",
        "classification problem.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 166
        },
        "id": "u55iZYEOrd0D",
        "outputId": "b208750a-c60b-439d-acf7-743a1db4a95a"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Get a batch of validation data\n",
        "inputs, classes = next(iter(dataloaders[\"validation\"]))\n",
        "\n",
        "# Make a grid from batch\n",
        "out = torchvision.utils.make_grid(inputs)\n",
        "\n",
        "imshow(out, title=[class_names[x] for x in classes])\n",
        "\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ds09ighxrd0D"
      },
      "source": [
        "Variational quantum circuit\n",
        "===========================\n",
        "\n",
        "We first define some quantum layers that will compose the quantum\n",
        "circuit.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "OIbOa-KBrd0D"
      },
      "outputs": [],
      "source": [
        "def H_layer(nqubits):\n",
        "    \"\"\"Layer of single-qubit Hadamard gates.\n",
        "    \"\"\"\n",
        "    for idx in range(nqubits):\n",
        "        qml.Hadamard(wires=idx)\n",
        "\n",
        "\n",
        "def RY_layer(w):\n",
        "    \"\"\"Layer of parametrized qubit rotations around the y axis.\n",
        "    \"\"\"\n",
        "    for idx, element in enumerate(w):\n",
        "        qml.RY(element, wires=idx)\n",
        "\n",
        "\n",
        "def entangling_layer(nqubits):\n",
        "    \"\"\"Layer of CNOTs followed by another shifted layer of CNOT.\n",
        "    \"\"\"\n",
        "    # In other words it should apply something like :\n",
        "    # CNOT  CNOT  CNOT  CNOT...  CNOT\n",
        "    #   CNOT  CNOT  CNOT...  CNOT\n",
        "    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2\n",
        "        qml.CNOT(wires=[i, i + 1])\n",
        "    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3\n",
        "        qml.CNOT(wires=[i, i + 1])"
      ]
    },
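    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see which wires `entangling_layer` couples, the two loops can be replayed\n",
        "without a quantum device. The sketch below only lists the (control, target)\n",
        "pairs in application order; `entangling_pairs` is a hypothetical helper introduced\n",
        "here for illustration and is not part of PennyLane.\n",
        "\n",
        "```python\n",
        "def entangling_pairs(nqubits):\n",
        "    # First layer: CNOTs starting on even wires, i = 0, 2, ...\n",
        "    even = [(i, i + 1) for i in range(0, nqubits - 1, 2)]\n",
        "    # Second, shifted layer: CNOTs starting on odd wires, i = 1, 3, ...\n",
        "    odd = [(i, i + 1) for i in range(1, nqubits - 1, 2)]\n",
        "    return even + odd\n",
        "\n",
        "\n",
        "for control, target in entangling_pairs(8):\n",
        "    print(control, target)\n",
        "```\n",
        "\n",
        "For `nqubits = 8` this gives four pairs from the even loop followed by three\n",
        "pairs from the odd loop, so every neighbouring pair of wires is coupled once."
      ]
    },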
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ue__y4sQrd0D"
      },
      "source": [
        "Now we define the quantum circuit through the PennyLane\n",
        "`qnode` decorator.\n",
        "\n",
        "The structure is that of a typical variational quantum circuit:\n",
        "\n",
        "-   **Embedding layer:** The circuit was designed to first initialize all\n",
        "    qubits in a balanced superposition of *up* and *down* states (the\n",
        "    `H_layer` call). In this notebook that call is commented out, so the\n",
        "    qubits start in the `|0>` state and are then rotated according to the\n",
        "    input parameters (local embedding).\n",
        "-   **Variational layers:** A sequence of trainable rotation layers and\n",
        "    constant entangling layers is applied.\n",
        "-   **Measurement layer:** For each qubit, the local expectation value\n",
        "    of the $Z$ operator is measured. This produces a classical output\n",
        "    vector, suitable for additional post-processing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "C-O_cK6jrd0D"
      },
      "outputs": [],
      "source": [
        "@qml.qnode(dev, interface=\"torch\")\n",
        "def quantum_net(q_input_features, q_weights_flat):\n",
        "    \"\"\"\n",
        "    The variational quantum circuit.\n",
        "    \"\"\"\n",
        "\n",
        "    # Reshape weights\n",
        "    q_weights = q_weights_flat.reshape(q_depth, n_qubits)\n",
        "\n",
        "    # Optional: start from state |+>, unbiased w.r.t. |0> and |1> (disabled here)\n",
        "    #H_layer(n_qubits)\n",
        "\n",
        "    # Embed features in the quantum node\n",
        "    RY_layer(q_input_features)\n",
        "\n",
        "    # Sequence of trainable variational layers\n",
        "    for k in range(q_depth):\n",
        "        entangling_layer(n_qubits)\n",
        "        RY_layer(q_weights[k])\n",
        "\n",
        "    # Expectation values in the Z basis\n",
        "    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]\n",
        "    return tuple(exp_vals)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7FNHREgVrd0D"
      },
      "source": [
        "Dressed quantum circuit\n",
        "=======================\n",
        "\n",
        "We can now define a custom `torch.nn.Module` representing a *dressed*\n",
        "quantum circuit.\n",
        "\n",
        "This is a concatenation of:\n",
        "\n",
        "-   A classical pre-processing layer (`nn.Linear`).\n",
        "-   A classical activation function (`torch.tanh`).\n",
        "-   A constant `np.pi/2.0` scaling.\n",
        "-   The previously defined quantum circuit (`quantum_net`).\n",
        "-   A classical post-processing layer (`nn.Linear`).\n",
        "\n",
        "The input of the module is a batch of vectors with 512 real parameters\n",
        "(features) and the output is a batch of vectors with four real outputs\n",
        "(one per image class, matching `nn.Linear(n_qubits, 4)` in the code below).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "C96kWCjWrd0D"
      },
      "outputs": [],
      "source": [
        "class DressedQuantumNet(nn.Module):\n",
        "    \"\"\"\n",
        "    Torch module implementing the *dressed* quantum net.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        \"\"\"\n",
        "        Definition of the *dressed* layout.\n",
        "        \"\"\"\n",
        "\n",
        "        super().__init__()\n",
        "        self.pre_net = nn.Linear(512, n_qubits)\n",
        "        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))\n",
        "        self.post_net = nn.Linear(n_qubits, 4)\n",
        "\n",
        "    def forward(self, input_features):\n",
        "        \"\"\"\n",
        "        Defining how tensors are supposed to move through the *dressed* quantum\n",
        "        net.\n",
        "        \"\"\"\n",
        "\n",
        "        # obtain the input features for the quantum circuit\n",
        "        # by reducing the feature dimension from 512 to n_qubits\n",
        "        pre_out = self.pre_net(input_features)\n",
        "        q_in = torch.tanh(pre_out) * np.pi / 2.0\n",
        "\n",
        "        # Apply the quantum circuit to each element of the batch and append to q_out\n",
        "        q_out = torch.Tensor(0, n_qubits)\n",
        "        q_out = q_out.to(device)\n",
        "        for elem in q_in:\n",
        "            q_out_elem = torch.hstack(quantum_net(elem, self.q_params)).float().unsqueeze(0)\n",
        "            q_out = torch.cat((q_out, q_out_elem))\n",
        "\n",
        "        # return the four-dimensional prediction from the postprocessing layer\n",
        "        return self.post_net(q_out)"
      ]
    },
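    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In `DressedQuantumNet.forward`, the pre-processing output is passed through\n",
        "`torch.tanh` and scaled by `np.pi / 2.0`, which keeps every embedding angle\n",
        "bounded by plus or minus pi/2 however large the linear output is. A minimal\n",
        "scalar sketch of that mapping follows; `embedding_angle` is a hypothetical\n",
        "helper that uses the standard `math` module instead of PyTorch.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def embedding_angle(pre_out):\n",
        "    # Scalar analogue of q_in = torch.tanh(pre_out) * np.pi / 2.0\n",
        "    return math.tanh(pre_out) * math.pi / 2.0\n",
        "\n",
        "\n",
        "# Large positive or negative inputs saturate near the bounds; zero maps to zero.\n",
        "for value in (-10.0, -1.0, 0.0, 1.0, 10.0):\n",
        "    print(value, embedding_angle(value))\n",
        "```\n",
        "\n",
        "Bounding the angles this way restricts each `RY` embedding rotation to at most a\n",
        "quarter turn in either direction."
      ]
    },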
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7hQlpDQdrd0D"
      },
      "source": [
        "Hybrid classical-quantum model\n",
        "==============================\n",
        "\n",
        "We are finally ready to build our full hybrid classical-quantum network.\n",
        "We follow the *transfer learning* approach:\n",
        "\n",
        "1.  First load the classical pre-trained network *ResNet18* from the\n",
        "    `torchvision.models` zoo.\n",
        "2.  Freeze all the weights since they should not be trained.\n",
        "3.  Replace the last fully connected layer with our trainable dressed\n",
        "    quantum circuit (`DressedQuantumNet`).\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The *ResNet18* model is automatically downloaded by PyTorch and it may\n",
        "take several minutes (only the first time).\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "MAh4FqBYrd0D",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "d71f9b2f-b0d6-4a81-eeb4-dc18d9a769a7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n",
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
            "100%|██████████| 44.7M/44.7M [00:00<00:00, 225MB/s]\n"
          ]
        }
      ],
      "source": [
        "model_hybrid = torchvision.models.resnet18(pretrained=True)\n",
        "\n",
        "for param in model_hybrid.parameters():\n",
        "    param.requires_grad = False\n",
        "\n",
        "\n",
        "# Notice that model_hybrid.fc is the last layer of ResNet18\n",
        "model_hybrid.fc = DressedQuantumNet()\n",
        "\n",
        "# Use CUDA or CPU according to the \"device\" object.\n",
        "model_hybrid = model_hybrid.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovX-Tkb0rd0E"
      },
      "source": [
        "Training and results\n",
        "====================\n",
        "\n",
        "Before training the network, we need to specify the *loss* function.\n",
        "\n",
        "As usual in classification problems, we use the *cross-entropy* loss, which is\n",
        "directly available within `torch.nn`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "gkmFIK4Brd0E"
      },
      "outputs": [],
      "source": [
        "criterion = nn.CrossEntropyLoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qhFORuvard0E"
      },
      "source": [
        "We also initialize the *Adam optimizer* which is called at each training\n",
        "step in order to update the weights of the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "_K9M-VPMrd0E"
      },
      "outputs": [],
      "source": [
        "optimizer_hybrid = optim.Adam(model_hybrid.fc.parameters(), lr=step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IGG8pksyrd0E"
      },
      "source": [
        "We schedule a reduction of the learning rate by a factor of\n",
        "`gamma_lr_scheduler` every 10 epochs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "nco3IrE6rd0E"
      },
      "outputs": [],
      "source": [
        "exp_lr_scheduler = lr_scheduler.StepLR(\n",
        "    optimizer_hybrid, step_size=10, gamma=gamma_lr_scheduler\n",
        ")"
      ]
    },
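    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of a step schedule can be written in closed form: after `epoch`\n",
        "scheduler steps the learning rate is the initial rate multiplied by `gamma`\n",
        "once for every completed block of `step_size` epochs. The sketch below evaluates\n",
        "that formula with the notebook values (`step = 0.0006`, `gamma_lr_scheduler = 0.1`,\n",
        "`step_size = 10`); `step_lr` is a hypothetical helper, not a PyTorch function.\n",
        "\n",
        "```python\n",
        "def step_lr(initial_lr, gamma, step_size, epoch):\n",
        "    # Step decay: multiply by gamma once per completed block of step_size epochs.\n",
        "    return initial_lr * gamma ** (epoch // step_size)\n",
        "\n",
        "\n",
        "for epoch in (0, 9, 10, 20):\n",
        "    print(epoch, step_lr(0.0006, 0.1, 10, epoch))\n",
        "```\n",
        "\n",
        "With `num_epochs = 1`, as set in this notebook, the rate never leaves its\n",
        "initial value; the schedule only matters for runs of more than 10 epochs."
      ]
    },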
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yJvpPNCRrd0E"
      },
      "source": [
        "What follows is a training function that will be called later. This\n",
        "function should return a trained model that can be used to make\n",
        "predictions (classifications).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "8uO5BrOPrd0E"
      },
      "outputs": [],
      "source": [
        "def train_model(model, criterion, optimizer, scheduler, num_epochs):\n",
        "    since = time.time()\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "    best_loss = 10000.0  # Large arbitrary number\n",
        "    best_acc_train = 0.0\n",
        "    best_loss_train = 10000.0  # Large arbitrary number\n",
        "    print(\"Training started:\")\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in [\"train\", \"validation\"]:\n",
        "            if phase == \"train\":\n",
        "                # Set model to training mode\n",
        "                model.train()\n",
        "            else:\n",
        "                # Set model to evaluate mode\n",
        "                model.eval()\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "\n",
        "            # Iterate over data.\n",
        "            n_batches = dataset_sizes[phase] // batch_size\n",
        "            it = 0\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                since_batch = time.time()\n",
        "                batch_size_ = len(inputs)\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # Track/compute gradient and make an optimization step only when training\n",
        "                with torch.set_grad_enabled(phase == \"train\"):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "                    if phase == \"train\":\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # Print iteration results\n",
        "                running_loss += loss.item() * batch_size_\n",
        "                batch_corrects = torch.sum(preds == labels.data).item()\n",
        "                running_corrects += batch_corrects\n",
        "                print(\n",
        "                    \"Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}\".format(\n",
        "                        phase,\n",
        "                        epoch + 1,\n",
        "                        num_epochs,\n",
        "                        it + 1,\n",
        "                        n_batches + 1,\n",
        "                        time.time() - since_batch,\n",
        "                    ),\n",
        "                    end=\"\\r\",\n",
        "                    flush=True,\n",
        "                )\n",
        "                it += 1\n",
        "\n",
        "            # Print epoch results\n",
        "            epoch_loss = running_loss / dataset_sizes[phase]\n",
        "            epoch_acc = running_corrects / dataset_sizes[phase]\n",
        "            print(\n",
        "                \"Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        \".format(\n",
        "                    \"train\" if phase == \"train\" else \"validation  \",\n",
        "                    epoch + 1,\n",
        "                    num_epochs,\n",
        "                    epoch_loss,\n",
        "                    epoch_acc,\n",
        "                )\n",
        "            )\n",
        "\n",
        "            # Check if this is the best model wrt previous epochs\n",
        "            if phase == \"validation\" and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "            if phase == \"validation\" and epoch_loss < best_loss:\n",
        "                best_loss = epoch_loss\n",
        "            if phase == \"train\" and epoch_acc > best_acc_train:\n",
        "                best_acc_train = epoch_acc\n",
        "            if phase == \"train\" and epoch_loss < best_loss_train:\n",
        "                best_loss_train = epoch_loss\n",
        "      \n",
        "            # Update learning rate\n",
        "            if phase == \"train\":\n",
        "                scheduler.step()\n",
        "\n",
        "    # Print final results\n",
        "    model.load_state_dict(best_model_wts)\n",
        "    time_elapsed = time.time() - since\n",
        "    print(\n",
        "        \"Training completed in {:.0f}m {:.0f}s\".format(time_elapsed // 60, time_elapsed % 60)\n",
        "    )\n",
        "    print(\"Best test loss: {:.4f} | Best test accuracy: {:.4f}\".format(best_loss, best_acc))\n",
        "    return model"
      ]
    },
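    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Inside `train_model`, each batch loss is multiplied by the actual batch length\n",
        "before being accumulated, so the epoch loss is a per-sample average even when\n",
        "the last batch is smaller than `batch_size`. A minimal sketch of that bookkeeping\n",
        "is shown below; `epoch_average` and the sample numbers are illustrative\n",
        "assumptions, not values from the training run.\n",
        "\n",
        "```python\n",
        "def epoch_average(batch_losses, batch_sizes):\n",
        "    # Weight each mean batch loss by its batch length, as running_loss does.\n",
        "    running_loss = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))\n",
        "    # Divide by the number of samples seen, the role played by dataset_sizes[phase].\n",
        "    return running_loss / sum(batch_sizes)\n",
        "\n",
        "\n",
        "# Two hypothetical batches: a full one of 6 samples and a final one of 2.\n",
        "print(epoch_average([1.0, 2.0], [6, 2]))\n",
        "```\n",
        "\n",
        "A plain mean of the two batch losses would give 1.5, whereas the weighted form\n",
        "gives 1.25 because the smaller batch contributes fewer samples."
      ]
    },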
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CmAl9fIQrd0E"
      },
      "source": [
        "We are ready to perform the actual training process.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "rAQBdDA_rd0E",
        "outputId": "fb4287bf-12ff-4c28-d497-18d0660c9543"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training started:\n",
            "Phase: train Epoch: 1/1 Loss: 1.2981 Acc: 0.3976        \n",
            "Phase: validation   Epoch: 1/1 Loss: 1.2214 Acc: 0.5529        \n",
            "Training completed in 11m 43s\n",
            "Best test loss: 1.2214 | Best test accuracy: 0.5529\n"
          ]
        }
      ],
      "source": [
        "model_hybrid = train_model(\n",
        "    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s4cEc4mird0E"
      },
      "source": [
        "Visualizing the model predictions\n",
        "=================================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ARvjv_lbrd0E"
      },
      "source": [
        "We first define a visualization function for a batch of validation data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "CtUY4xU_rd0E"
      },
      "outputs": [],
      "source": [
        "def visualize_model(model, num_images=6, fig_name=\"Predictions\"):\n",
        "    images_so_far = 0\n",
        "    _fig = plt.figure(fig_name)\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        for _i, (inputs, labels) in enumerate(dataloaders[\"validation\"]):\n",
        "            inputs = inputs.to(device)\n",
        "            labels = labels.to(device)\n",
        "            outputs = model(inputs)\n",
        "            _, preds = torch.max(outputs, 1)\n",
        "            for j in range(inputs.size()[0]):\n",
        "                images_so_far += 1\n",
        "                ax = plt.subplot(num_images // 2, 2, images_so_far)\n",
        "                ax.axis(\"off\")\n",
        "                ax.set_title(\"[{}]\".format(class_names[preds[j]]))\n",
        "                imshow(inputs.cpu().data[j])\n",
        "                if images_so_far == num_images:\n",
        "                    return"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F6I5Rk92rd0E"
      },
      "source": [
        "Finally, we can run the previous function to see a batch of images with\n",
        "the corresponding predictions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "-RoYcZQXrd0E",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 428
        },
        "outputId": "117e96f2-051a-4880-8046-c353e49a60a4"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 6 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "visualize_model(model_hybrid, num_images=6)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IBXZTnzjrd0E"
      },
      "source": [
        "References\n",
        "==========\n",
        "\n",
        "\\[1\\] Andrea Mari, Thomas R. Bromley, Josh Izaac, Maria Schuld, and\n",
        "Nathan Killoran. *Transfer learning in hybrid classical-quantum neural\n",
        "networks*. arXiv:1912.08278 (2019).\n",
        "\n",
        "\\[2\\] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and\n",
        "Andrew Y. Ng. *Self-taught learning: transfer learning from unlabeled\n",
        "data*. Proceedings of the 24th International Conference on Machine\n",
        "Learning, 759--766 (2007).\n",
        "\n",
        "\\[3\\] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. *Deep\n",
        "residual learning for image recognition*. Proceedings of the IEEE\n",
        "Conference on Computer Vision and Pattern Recognition, 770--778 (2016).\n",
        "\n",
        "\\[4\\] Ville Bergholm, Josh Izaac, Maria Schuld, Christian Gogolin,\n",
        "Carsten Blank, Keri McKiernan, and Nathan Killoran. *PennyLane:\n",
        "Automatic differentiation of hybrid quantum-classical computations*.\n",
        "arXiv:1811.04968 (2018).\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.16"
    },
    "colab": {
      "provenance": [],
      "gpuType": "V100"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}