{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 443,
      "metadata": {
        "id": "UJOq3mdA8PAH",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "cfc91fa1-9c4e-44cc-a600-6f4d149dd0e3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since beginning of run: 1695696940.1295893\n",
            "Tue Sep 26 02:55:40 2023\n"
          ]
        }
      ],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "%matplotlib inline\n",
        "\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')\n",
        "# !pip install pennylane\n",
        "\n",
        "import time\n",
        "seconds = time.time()\n",
        "print(\"Time in seconds since beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 444,
      "metadata": {
        "id": "5ljdosVS8PAP"
      },
      "outputs": [],
      "source": [
        "# Some parts of this code are based on the Python script:\n",
        "# https://github.com/pytorch/tutorials/blob/master/beginner_source/transfer_learning_tutorial.py\n",
        "# License: BSD\n",
        "\n",
        "import os\n",
        "import copy\n",
        "\n",
        "# PyTorch\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.optim import lr_scheduler\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Pennylane\n",
        "import pennylane as qml\n",
        "from pennylane import numpy as np\n",
        "\n",
        "torch.manual_seed(42)\n",
        "np.random.seed(42)\n",
        "\n",
        "# Plotting\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# OpenMP: number of parallel threads.\n",
        "os.environ[\"OMP_NUM_THREADS\"] = \"1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1AFilzYk8PAQ"
      },
      "source": [
        "Setting of the main hyper-parameters of the model\n",
        "=================================================\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "To reproduce the results of Ref. \\[1\\], `num_epochs` should be set to\n",
        "`30`, which may take a long time. We suggest first trying\n",
        "`num_epochs=1` and, if everything runs smoothly, increasing it to a larger\n",
        "value.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 445,
      "metadata": {
        "id": "5LRcEYZg8PAR"
      },
      "outputs": [],
      "source": [
        "n_qubits = 4                # Number of qubits\n",
        "step = 0.0004               # Learning rate\n",
        "batch_size = 4              # Number of samples for each training step\n",
        "num_epochs = 5              # Number of training epochs\n",
        "q_depth = 8                 # Depth of the quantum circuit (number of variational layers)\n",
        "gamma_lr_scheduler = 0.1    # Learning rate reduction applied every 10 epochs.\n",
        "q_delta = 0.01              # Initial spread of random quantum weights\n",
        "start_time = time.time()    # Start of the computation timer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NlU2Q7zd8PAR"
      },
      "source": [
        "We initialize a PennyLane device with a `default.qubit` backend.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 446,
      "metadata": {
        "id": "0prgZPLK8PAR"
      },
      "outputs": [],
      "source": [
        "dev = qml.device(\"default.qubit\", wires=n_qubits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "54jRIpbZ8PAS"
      },
      "source": [
        "We configure PyTorch to use CUDA only if available. Otherwise the CPU is\n",
        "used.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 447,
      "metadata": {
        "id": "23nQUjLm8PAS"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-AJzWJGi8PAT"
      },
      "source": [
        "Dataset loading\n",
        "===============\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The dataset containing images of *ants* and *bees* can be downloaded\n",
        "[here](https://download.pytorch.org/tutorial/hymenoptera_data.zip) and\n",
        "should be extracted into the subfolder `../_data/hymenoptera_data`.\n",
        "The code below, however, reads from `data_dir`, a Google Drive folder of\n",
        "brain tumor MRI images split into 10 classes.\n",
        ":::\n",
        "\n",
        "The ants-and-bees dataset is very small (roughly 250 images): too small for\n",
        "training a classical or quantum model from scratch, but sufficient\n",
        "when using a *transfer learning* approach.\n",
        "\n",
        "The PyTorch packages `torchvision` and `torch.utils.data` are used for\n",
        "loading the dataset and performing standard preliminary image\n",
        "operations: resize, center crop, normalize, *etc.*\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 448,
      "metadata": {
        "id": "XaNa12un8PAT"
      },
      "outputs": [],
      "source": [
        "data_transforms = {\n",
        "    \"train\": transforms.Compose(\n",
        "        [\n",
        "            # transforms.RandomResizedCrop(224),     # uncomment for data augmentation\n",
        "            # transforms.RandomHorizontalFlip(),     # uncomment for data augmentation\n",
        "            transforms.Resize(256),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            # Normalize input channels using mean values and standard deviations of ImageNet.\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "    \"val\": transforms.Compose(\n",
        "        [\n",
        "            transforms.Resize(256),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "}\n",
        "\n",
        "data_dir = \"/content/drive/MyDrive/Colab Notebooks/data/Shuffle Split 10 of 17 Classes Big Brain Tumor MRI Images\"\n",
        "image_datasets = {\n",
        "    x if x == \"train\" else \"validation\": datasets.ImageFolder(\n",
        "        os.path.join(data_dir, x), data_transforms[x]\n",
        "    )\n",
        "    for x in [\"train\", \"val\"]\n",
        "}\n",
        "dataset_sizes = {x: len(image_datasets[x]) for x in [\"train\", \"validation\"]}\n",
        "class_names = image_datasets[\"train\"].classes\n",
        "\n",
        "# Initialize dataloader\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}\n",
        "\n",
        "# function to plot images\n",
        "def imshow(inp, title=None):\n",
        "    \"\"\"Display image from tensor.\"\"\"\n",
        "    inp = inp.numpy().transpose((1, 2, 0))\n",
        "    # Inverse of the initial normalization operation.\n",
        "    mean = np.array([0.485, 0.456, 0.406])\n",
        "    std = np.array([0.229, 0.224, 0.225])\n",
        "    inp = std * inp + mean\n",
        "    inp = np.clip(inp, 0, 1)\n",
        "    plt.imshow(inp)\n",
        "    if title is not None:\n",
        "        plt.title(title)"
      ]
    },
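    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of what `transforms.Normalize` and the inverse step in\n",
        "`imshow` do, the NumPy snippet below standardizes one made-up RGB pixel with the\n",
        "ImageNet channel statistics and then undoes the operation. The pixel values and\n",
        "the helper names `normalize` and `denormalize` are illustrative assumptions, not\n",
        "part of the pipeline above.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# ImageNet channel statistics, as used in data_transforms above.\n",
        "MEAN = np.array([0.485, 0.456, 0.406])\n",
        "STD = np.array([0.229, 0.224, 0.225])\n",
        "\n",
        "\n",
        "def normalize(pixel):\n",
        "    '''Per-channel standardization: (x - mean) / std.'''\n",
        "    return (pixel - MEAN) / STD\n",
        "\n",
        "\n",
        "def denormalize(pixel):\n",
        "    '''Inverse operation, as applied in imshow before plotting.'''\n",
        "    return np.clip(STD * pixel + MEAN, 0, 1)\n",
        "\n",
        "\n",
        "pixel = np.array([0.2, 0.5, 0.9])  # hypothetical RGB values in [0, 1]\n",
        "restored = denormalize(normalize(pixel))\n",
        "print(np.allclose(restored, pixel))\n",
        "```\n",
        "\n",
        "The clipping only matters for display: floating-point round-off can push a\n",
        "restored value slightly outside the range that `plt.imshow` accepts.\n"
      ]
    },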
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ANdmcnR98PAU"
      },
      "source": [
        "Let us show a batch of the test data, just to have an idea of the\n",
        "classification problem.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 449,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 207
        },
        "id": "QzIKQxS78PAU",
        "outputId": "55823e0b-d4f1-400c-d977-7d95c0e18b18"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Get a batch of validation data\n",
        "inputs, classes = next(iter(dataloaders[\"validation\"]))\n",
        "\n",
        "# Make a grid from batch\n",
        "out = torchvision.utils.make_grid(inputs)\n",
        "\n",
        "imshow(out, title=[class_names[x] for x in classes])\n",
        "\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_ULbO8f28PAU"
      },
      "source": [
        "Variational quantum circuit\n",
        "===========================\n",
        "\n",
        "We first define some quantum layers that will compose the quantum\n",
        "circuit.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 450,
      "metadata": {
        "id": "6gMomjvL8PAV"
      },
      "outputs": [],
      "source": [
        "def H_layer(nqubits):\n",
        "    \"\"\"Layer of single-qubit Hadamard gates.\n",
        "    \"\"\"\n",
        "    for idx in range(nqubits):\n",
        "        qml.Hadamard(wires=idx)\n",
        "\n",
        "\n",
        "def RY_layer(w):\n",
        "    \"\"\"Layer of parametrized qubit rotations around the y axis.\n",
        "    \"\"\"\n",
        "    for idx, element in enumerate(w):\n",
        "        qml.RY(element, wires=idx)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0iroynmF8PAV"
      },
      "source": [
        "Now we define the quantum circuit through the PennyLane\n",
        "[qnode]{.title-ref} decorator.\n",
        "\n",
        "The structure is that of a typical variational quantum circuit:\n",
        "\n",
        "-   **Embedding layer:** All qubits are first initialized in a balanced\n",
        "    superposition of *up* and *down* states, then they are rotated\n",
        "    according to the input parameters (local embedding).\n",
        "-   **Variational layers:** A sequence of trainable rotation layers is\n",
        "    applied. The circuit defined below contains no entangling layers.\n",
        "-   **Measurement layer:** For each qubit, the local expectation value\n",
        "    of the $Z$ operator is measured. This produces a classical output\n",
        "    vector, suitable for additional post-processing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 451,
      "metadata": {
        "id": "ONyq04RY8PAV"
      },
      "outputs": [],
      "source": [
        "@qml.qnode(dev, interface=\"torch\")\n",
        "def quantum_net(q_input_features, q_weights_flat):\n",
        "    \"\"\"\n",
        "    The variational quantum circuit.\n",
        "    \"\"\"\n",
        "\n",
        "    # Reshape weights\n",
        "    q_weights = q_weights_flat.reshape(q_depth, n_qubits)\n",
        "\n",
        "    # Start from state |+> , unbiased w.r.t. |0> and |1>\n",
        "    H_layer(n_qubits)\n",
        "\n",
        "    # Embed features in the quantum node\n",
        "    RY_layer(q_input_features)\n",
        "\n",
        "    # Sequence of trainable variational layers\n",
        "    for k in range(q_depth):\n",
        "        RY_layer(q_weights[k])\n",
        "\n",
        "    # Expectation values in the Z basis\n",
        "    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]\n",
        "    return tuple(exp_vals)"
      ]
    },
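    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `quantum_net` as written applies only `RY` rotations and no entangling\n",
        "gate, each qubit evolves independently and rotations about the same axis add up.\n",
        "The rough NumPy check below (independent of PennyLane; the helper names and the\n",
        "sample angles are assumptions made for illustration) suggests that each output\n",
        "then reduces to `<Z> = -sin(theta)`, where `theta` is the input angle plus the\n",
        "sum of that qubit's weights.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "H = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)\n",
        "Z = np.array([[1.0, 0.0], [0.0, -1.0]])\n",
        "\n",
        "\n",
        "def ry(theta):\n",
        "    '''Matrix of a single-qubit rotation about the y axis.'''\n",
        "    c, s = np.cos(theta / 2.0), np.sin(theta / 2.0)\n",
        "    return np.array([[c, -s], [s, c]])\n",
        "\n",
        "\n",
        "def single_qubit_expval(x, weights):\n",
        "    '''<Z> after H, RY(x), then one RY per entry of weights.'''\n",
        "    state = ry(x) @ H @ np.array([1.0, 0.0])\n",
        "    for w in weights:\n",
        "        state = ry(w) @ state\n",
        "    return float(state @ Z @ state)\n",
        "\n",
        "\n",
        "x = 0.3  # hypothetical embedded feature\n",
        "weights = [0.01, -0.02, 0.05, 0.0, 0.03, -0.01, 0.02, 0.04]  # one per layer, depth 8\n",
        "layered = single_qubit_expval(x, weights)\n",
        "collapsed = -np.sin(x + sum(weights))\n",
        "print(np.isclose(layered, collapsed))\n",
        "```\n",
        "\n",
        "If this holds, stacking rotation-only layers adds no expressive power beyond a\n",
        "single rotation per qubit; an entangling gate between the rotation layers would\n",
        "be needed for the depth to matter.\n"
      ]
    },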
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4eG97j4f8PAV"
      },
      "source": [
        "Dressed quantum circuit\n",
        "=======================\n",
        "\n",
        "We can now define a custom `torch.nn.Module` representing a *dressed*\n",
        "quantum circuit.\n",
        "\n",
        "This is a concatenation of:\n",
        "\n",
        "-   A classical pre-processing layer (`nn.Linear`).\n",
        "-   A classical activation function (`torch.tanh`).\n",
        "-   A constant `np.pi/2.0` scaling.\n",
        "-   The previously defined quantum circuit (`quantum_net`).\n",
        "-   A classical post-processing layer (`nn.Linear`).\n",
        "\n",
        "The input of the module is a batch of vectors with 2048 real parameters\n",
        "(features) and the output is a batch of vectors with ten real outputs\n",
        "(one per image class).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 452,
      "metadata": {
        "id": "hIljGdv_8PAW"
      },
      "outputs": [],
      "source": [
        "class DressedQuantumNet(nn.Module):\n",
        "    \"\"\"\n",
        "    Torch module implementing the *dressed* quantum net.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        \"\"\"\n",
        "        Definition of the *dressed* layout.\n",
        "        \"\"\"\n",
        "\n",
        "        super().__init__()\n",
        "        self.pre_net = nn.Linear(2048, n_qubits)\n",
        "        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))\n",
        "        self.post_net = nn.Linear(n_qubits, 10)\n",
        "\n",
        "    def forward(self, input_features):\n",
        "        \"\"\"\n",
        "        Defining how tensors are supposed to move through the *dressed* quantum\n",
        "        net.\n",
        "        \"\"\"\n",
        "\n",
        "        # obtain the input features for the quantum circuit\n",
        "        # by reducing the feature dimension from 2048 to 4\n",
        "        pre_out = self.pre_net(input_features)\n",
        "        q_in = torch.tanh(pre_out) * np.pi / 2.0\n",
        "\n",
        "        # Apply the quantum circuit to each element of the batch and append to q_out\n",
        "        q_out = torch.Tensor(0, n_qubits)\n",
        "        q_out = q_out.to(device)\n",
        "        for elem in q_in:\n",
        "            q_out_elem = torch.hstack(quantum_net(elem, self.q_params)).float().unsqueeze(0)\n",
        "            q_out = torch.cat((q_out, q_out_elem))\n",
        "\n",
        "        # return the ten-dimensional prediction from the postprocessing layer\n",
        "        return self.post_net(q_out)\n"
      ]
    },
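    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the tensor shapes in `forward` concrete, here is a shape-only sketch in\n",
        "NumPy. Random matrices stand in for the two `nn.Linear` layers (biases omitted)\n",
        "and the quantum node is replaced by a classical placeholder, `-sin`, so nothing\n",
        "below reproduces the trained model. All variable names are illustrative\n",
        "assumptions.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "batch, n_features, n_qubits, n_classes = 4, 2048, 4, 10\n",
        "\n",
        "features = rng.normal(size=(batch, n_features))  # stand-in for the backbone output\n",
        "w_pre = rng.normal(size=(n_features, n_qubits)) * 0.01\n",
        "w_post = rng.normal(size=(n_qubits, n_classes))\n",
        "\n",
        "# Pre-processing: linear map, then squash the result into [-pi/2, pi/2].\n",
        "q_in = np.tanh(features @ w_pre) * np.pi / 2.0\n",
        "\n",
        "# Placeholder for the quantum node: any map returning values in [-1, 1].\n",
        "q_out = -np.sin(q_in)\n",
        "\n",
        "# Post-processing: one logit per class.\n",
        "logits = q_out @ w_post\n",
        "print(q_in.shape, q_out.shape, logits.shape)  # (4, 4) (4, 4) (4, 10)\n",
        "```\n",
        "\n",
        "The bottleneck is the point of the design: every image is compressed to\n",
        "`n_qubits` angles before the circuit sees it, and the class scores are a linear\n",
        "readout of only `n_qubits` expectation values.\n"
      ]
    },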
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E8-EDnhn8PAW"
      },
      "source": [
        "Hybrid classical-quantum model\n",
        "==============================\n",
        "\n",
        "We are finally ready to build our full hybrid classical-quantum network.\n",
        "We follow the *transfer learning* approach:\n",
        "\n",
        "1.  First load the classical pre-trained network *ResNet50* from the\n",
        "    `torchvision.models` zoo.\n",
        "2.  Freeze all the weights since they should not be trained.\n",
        "3.  Replace the last fully connected layer with our trainable dressed\n",
        "    quantum circuit (`DressedQuantumNet`).\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The *ResNet50* model is automatically downloaded by PyTorch and it may\n",
        "take several minutes (only the first time).\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 453,
      "metadata": {
        "id": "lnJnW_ra8PAX"
      },
      "outputs": [],
      "source": [
        "model_hybrid = torchvision.models.resnet50(pretrained=True)\n",
        "\n",
        "for param in model_hybrid.parameters():\n",
        "    param.requires_grad = False\n",
        "\n",
        "\n",
        "# Notice that model_hybrid.fc is the last layer of ResNet50\n",
        "model_hybrid.fc = DressedQuantumNet()\n",
        "\n",
        "# Use CUDA or CPU according to the \"device\" object.\n",
        "model_hybrid = model_hybrid.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5k96EBuZ8PAX"
      },
      "source": [
        "Training and results\n",
        "====================\n",
        "\n",
        "Before training the network we need to specify the *loss* function.\n",
        "\n",
        "We use, as usual in classification problems, the *cross-entropy* loss, which is\n",
        "directly available within `torch.nn`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 454,
      "metadata": {
        "id": "BKvfgR5N8PAX"
      },
      "outputs": [],
      "source": [
        "criterion = nn.CrossEntropyLoss()"
      ]
    },
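    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on the loss scale, the sketch below computes cross-entropy by\n",
        "hand for a single sample, using only the standard library. With ten classes, a\n",
        "model that assigns equal logits to every class should score ln(10), roughly\n",
        "2.303, which is a useful baseline when reading a training log. The function\n",
        "name and the sample logits are assumptions for illustration;\n",
        "`nn.CrossEntropyLoss` additionally averages over the batch by default.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def cross_entropy(logits, target):\n",
        "    '''Negative log-softmax of the target class for one sample.'''\n",
        "    shift = max(logits)  # subtract the max for numerical stability\n",
        "    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))\n",
        "    return log_sum_exp - logits[target]\n",
        "\n",
        "\n",
        "uniform = cross_entropy([0.0] * 10, target=3)\n",
        "confident = cross_entropy([5.0] + [0.0] * 9, target=0)\n",
        "print(round(uniform, 4), round(confident, 4))\n",
        "```\n",
        "\n",
        "A loss well below the uniform baseline indicates that the model is placing more\n",
        "probability on the correct class than chance would.\n"
      ]
    },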
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UUvuVdii8PAX"
      },
      "source": [
        "We also initialize the *Adam optimizer* which is called at each training\n",
        "step in order to update the weights of the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 455,
      "metadata": {
        "id": "bPI2SbMQ8PAX"
      },
      "outputs": [],
      "source": [
        "optimizer_hybrid = optim.Adam(model_hybrid.fc.parameters(), lr=step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a8wMKvP48PAY"
      },
      "source": [
        "We schedule to reduce the learning rate by a factor of\n",
        "`gamma_lr_scheduler` every 10 epochs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 456,
      "metadata": {
        "id": "dLQsPIzy8PAY"
      },
      "outputs": [],
      "source": [
        "exp_lr_scheduler = lr_scheduler.StepLR(\n",
        "    optimizer_hybrid, step_size=10, gamma=gamma_lr_scheduler\n",
        ")"
      ]
    },
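    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The decay rule can be written in closed form. The sketch below is a plain-Python\n",
        "restatement of a step schedule, `lr = step * gamma ** (epoch // step_size)`; it\n",
        "is not PyTorch's implementation, and the helper name `step_lr` is an assumption.\n",
        "With the hyper-parameters set earlier, it suggests that a 5-epoch run never\n",
        "reaches the first decay.\n",
        "\n",
        "```python\n",
        "def step_lr(base_lr, gamma, step_size, epoch):\n",
        "    '''Learning rate in effect during the given zero-based epoch.'''\n",
        "    return base_lr * gamma ** (epoch // step_size)\n",
        "\n",
        "\n",
        "step, gamma_lr_scheduler = 0.0004, 0.1  # values from the hyper-parameter cell\n",
        "for epoch in (0, 4, 9, 10, 20):\n",
        "    print(epoch, step_lr(step, gamma_lr_scheduler, 10, epoch))\n",
        "```\n",
        "\n",
        "Epochs 0 through 9 all run at the base rate, so the scheduler only has an\n",
        "effect once `num_epochs` exceeds 10.\n"
      ]
    },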
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q-xTUZhq8PAY"
      },
      "source": [
        "What follows is a training function that will be called later. This\n",
        "function should return a trained model that can be used to make\n",
        "predictions (classifications).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 457,
      "metadata": {
        "id": "rppVRya_8PAY"
      },
      "outputs": [],
      "source": [
        "def train_model(model, criterion, optimizer, scheduler, num_epochs):\n",
        "    since = time.time()\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "    best_loss = 10000.0  # Large arbitrary number\n",
        "    best_acc_train = 0.0\n",
        "    best_loss_train = 10000.0  # Large arbitrary number\n",
        "    print(\"Training started:\")\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in [\"train\", \"validation\"]:\n",
        "            if phase == \"train\":\n",
        "                # Set model to training mode\n",
        "                model.train()\n",
        "            else:\n",
        "                # Set model to evaluate mode\n",
        "                model.eval()\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "\n",
        "            # Iterate over data.\n",
        "            n_batches = dataset_sizes[phase] // batch_size\n",
        "            it = 0\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                since_batch = time.time()\n",
        "                batch_size_ = len(inputs)\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # Track/compute gradient and make an optimization step only when training\n",
        "                with torch.set_grad_enabled(phase == \"train\"):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "                    if phase == \"train\":\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # Print iteration results\n",
        "                running_loss += loss.item() * batch_size_\n",
        "                batch_corrects = torch.sum(preds == labels.data).item()\n",
        "                running_corrects += batch_corrects\n",
        "                print(\n",
        "                    \"Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}\".format(\n",
        "                        phase,\n",
        "                        epoch + 1,\n",
        "                        num_epochs,\n",
        "                        it + 1,\n",
        "                        n_batches + 1,\n",
        "                        time.time() - since_batch,\n",
        "                    ),\n",
        "                    end=\"\\r\",\n",
        "                    flush=True,\n",
        "                )\n",
        "                it += 1\n",
        "\n",
        "            # Print epoch results\n",
        "            epoch_loss = running_loss / dataset_sizes[phase]\n",
        "            epoch_acc = running_corrects / dataset_sizes[phase]\n",
        "            print(\n",
        "                \"Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        \".format(\n",
        "                    \"train\" if phase == \"train\" else \"validation  \",\n",
        "                    epoch + 1,\n",
        "                    num_epochs,\n",
        "                    epoch_loss,\n",
        "                    epoch_acc,\n",
        "                )\n",
        "            )\n",
        "\n",
        "            # Check if this is the best model wrt previous epochs\n",
        "            if phase == \"validation\" and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "            if phase == \"validation\" and epoch_loss < best_loss:\n",
        "                best_loss = epoch_loss\n",
        "            if phase == \"train\" and epoch_acc > best_acc_train:\n",
        "                best_acc_train = epoch_acc\n",
        "            if phase == \"train\" and epoch_loss < best_loss_train:\n",
        "                best_loss_train = epoch_loss\n",
        "\n",
        "            # Update learning rate\n",
        "            if phase == \"train\":\n",
        "                scheduler.step()\n",
        "\n",
        "    # Print final results\n",
        "    model.load_state_dict(best_model_wts)\n",
        "    time_elapsed = time.time() - since\n",
        "    print(\n",
        "        \"Training completed in {:.0f}m {:.0f}s\".format(time_elapsed // 60, time_elapsed % 60)\n",
        "    )\n",
        "    print(\"Best test loss: {:.4f} | Best test accuracy: {:.4f}\".format(best_loss, best_acc))\n",
        "    return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a_XtRwDI8PAZ"
      },
      "source": [
        "We are ready to perform the actual training process.\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from IPython.display import display, Javascript\n",
        "\n",
        "# Run this cell to keep Colab awake\n",
        "display(Javascript('''\n",
        "  function keep_colab_awake(){\n",
        "    console.log(\"Colab is being kept awake.\");\n",
        "    document.querySelector(\"#top-toolbar > colab-connect-button\").shadowRoot.querySelector(\"#connect\").click();\n",
        "    document.querySelector(\"body > colab-sandbox-output > div > div.output.container.output-wrapper > div.output > pre\").innerText;\n",
        "    setTimeout(keep_colab_awake, 61000);\n",
        "  }\n",
        "  keep_colab_awake();\n",
        "'''))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "p2W621Tsy2hY",
        "outputId": "66ebfb1d-5449-408d-a61f-f2264f3953e9"
      },
      "execution_count": 458,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "  function keep_colab_awake(){\n",
              "    console.log(\"Colab is being kept awake.\");\n",
              "    document.querySelector(\"#top-toolbar > colab-connect-button\").shadowRoot.querySelector(\"#connect\").click();\n",
              "    document.querySelector(\"body > colab-sandbox-output > div > div.output.container.output-wrapper > div.output > pre\").innerText;\n",
              "    setTimeout(keep_colab_awake, 61000);\n",
              "  }\n",
              "  keep_colab_awake();\n"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 459,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "5VgfdD3-8PAZ",
        "outputId": "90f8608e-aa55-4ce0-969e-4caff873f465"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training started:\n",
            "Phase: train Epoch: 1/5 Loss: 2.2026 Acc: 0.2093        \n",
            "Phase: validation   Epoch: 1/5 Loss: 2.0288 Acc: 0.2659        \n",
            "Phase: train Epoch: 2/5 Loss: 2.0414 Acc: 0.2716        \n",
            "Phase: validation   Epoch: 2/5 Loss: 1.9228 Acc: 0.3117        \n",
            "Phase: train Epoch: 3/5 Loss: 1.9595 Acc: 0.2944        \n",
            "Phase: validation   Epoch: 3/5 Loss: 1.8117 Acc: 0.3728        \n",
            "Phase: train Epoch: 4/5 Loss: 1.9000 Acc: 0.3280        \n",
            "Phase: validation   Epoch: 4/5 Loss: 1.7902 Acc: 0.3728        \n",
            "Phase: train Epoch: 5/5 Loss: 1.8263 Acc: 0.3435        \n",
            "Phase: validation   Epoch: 5/5 Loss: 1.6910 Acc: 0.4156        \n",
            "Training completed in 19m 24s\n",
            "Best test loss: 1.6910 | Best test accuracy: 0.4156\n"
          ]
        }
      ],
      "source": [
        "model_hybrid = train_model(\n",
        "    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AG82Ot6Y8PAZ"
      },
      "source": [
        "Visualizing the model predictions\n",
        "=================================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cwycKwbd8PAZ"
      },
      "source": [
        "We first define a visualization function for a batch of test data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 460,
      "metadata": {
        "id": "_8R2rHzF8PAZ"
      },
      "outputs": [],
      "source": [
        "def visualize_model(model, num_images=6, fig_name=\"Predictions\"):\n",
        "    images_so_far = 0\n",
        "    _fig = plt.figure(fig_name)\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        for _i, (inputs, labels) in enumerate(dataloaders[\"validation\"]):\n",
        "            inputs = inputs.to(device)\n",
        "            labels = labels.to(device)\n",
        "            outputs = model(inputs)\n",
        "            _, preds = torch.max(outputs, 1)\n",
        "            for j in range(inputs.size()[0]):\n",
        "                images_so_far += 1\n",
        "                ax = plt.subplot(num_images // 2, 2, images_so_far)\n",
        "                ax.axis(\"off\")\n",
        "                ax.set_title(\"[{}]\".format(class_names[preds[j]]))\n",
        "                imshow(inputs.cpu().data[j])\n",
        "                if images_so_far == num_images:\n",
        "                    return"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LQvJfmme8PAa"
      },
      "source": [
        "Finally, we can run the previous function to see a batch of images with\n",
        "the corresponding predictions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 461,
      "metadata": {
        "id": "mKBJn2x68PAa",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 428
        },
        "outputId": "93116c97-5cd0-4cbb-d747-90037371c4a5"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 16 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "visualize_model(model_hybrid, num_images=16)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "seconds = time.time()\n",
        "print(\"Time in seconds since end of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "id": "D3AaQc2xMk-G",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "f982c3c4-768d-408b-9ecb-efa2c87d1b39"
      },
      "execution_count": 462,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since end of run: 1695698106.8667762\n",
            "Tue Sep 26 03:15:06 2023\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# from google.colab import runtime\n",
        "# runtime.unassign()"
      ],
      "metadata": {
        "id": "fALJ8tZXA0to"
      },
      "execution_count": 463,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0yhgWSns8PAa"
      },
      "source": [
        "References\n",
        "==========\n",
        "\n",
        "\\[1\\] Andrea Mari, Thomas R. Bromley, Josh Izaac, Maria Schuld, and\n",
        "Nathan Killoran. *Transfer learning in hybrid classical-quantum neural\n",
        "networks*. arXiv:1912.08278 (2019).\n",
        "\n",
        "\\[2\\] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and\n",
        "Andrew Y Ng. *Self-taught learning: transfer learning from unlabeled\n",
        "data*. Proceedings of the 24th International Conference on Machine\n",
        "Learning, 759--766 (2007).\n",
        "\n",
        "\\[3\\] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. *Deep\n",
        "residual learning for image recognition*. Proceedings of the IEEE\n",
        "Conference on Computer Vision and Pattern Recognition, 770--778 (2016).\n",
        "\n",
        "\\[4\\] Ville Bergholm, Josh Izaac, Maria Schuld, Christian Gogolin,\n",
        "Carsten Blank, Keri McKiernan, and Nathan Killoran. *PennyLane:\n",
        "Automatic differentiation of hybrid quantum-classical computations*.\n",
        "arXiv:1811.04968 (2018).\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.17"
    },
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "V100"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}