{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 77,
      "metadata": {
        "id": "WNXEyZ23rdz-"
      },
      "outputs": [],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "WYVI3RMhAdap"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 78,
      "metadata": {
        "id": "sANL984ird0B",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "7c3a12a2-ff66-4596-9aba-8ff998e6d848"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since beginning of run: 1685943611.0886571\n",
            "Mon Jun  5 05:40:11 2023\n"
          ]
        }
      ],
      "source": [
        "# !pip install pennylane\n",
        "# Some parts of this code are based on the Python script:\n",
        "# https://github.com/pytorch/tutorials/blob/master/beginner_source/transfer_learning_tutorial.py\n",
        "# License: BSD\n",
        "\n",
        "import time\n",
        "import os\n",
        "import copy\n",
        "\n",
        "# PyTorch\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.optim import lr_scheduler\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms\n",
        "from PIL import Image\n",
        "\n",
        "# Pennylane\n",
        "import pennylane as qml\n",
        "from pennylane import numpy as np\n",
        "\n",
        "torch.manual_seed(42)\n",
        "np.random.seed(42)\n",
        "\n",
        "# Plotting\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# OpenMP: number of parallel threads.\n",
        "os.environ[\"OMP_NUM_THREADS\"] = \"16\"\n",
        "\n",
        "seconds = time.time()\n",
        "print(\"Time in seconds since beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7eGTk_Prd0B"
      },
      "source": [
        "Setting of the main hyper-parameters of the model\n",
        "=================================================\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "To reproduce the results of Ref. \\[1\\], `num_epochs` should be set to\n",
        "`30` which may take a long time. We suggest to first try with\n",
        "`num_epochs=1` and, if everything runs smoothly, increase it to a larger\n",
        "value.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 79,
      "metadata": {
        "id": "Wncy61Mdrd0B"
      },
      "outputs": [],
      "source": [
        "n_qubits = 20               # Number of qubits\n",
        "step = 0.0006               # Learning rate\n",
        "batch_size = 17             # Number of samples for each training step\n",
        "num_epochs = 30             # Number of training epochs\n",
        "q_depth = 6                 # Depth of the quantum circuit (number of variational layers)\n",
        "gamma_lr_scheduler = 0.1    # Learning rate reduction applied every 10 epochs.\n",
        "q_delta = 0.01              # Initial spread of random quantum weights\n",
        "start_time = time.time()    # Start of the computation timer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVqjLo8Rrd0B"
      },
      "source": [
        "We initialize a PennyLane device with a `default.qubit` backend.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 80,
      "metadata": {
        "id": "qLOa5trRrd0B"
      },
      "outputs": [],
      "source": [
        "dev = qml.device(\"default.qubit\", wires=n_qubits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dxlK7Jjtrd0C"
      },
      "source": [
        "We configure PyTorch to use CUDA only if available. Otherwise the CPU is\n",
        "used.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 81,
      "metadata": {
        "id": "Br_YGwRDrd0C"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WFHuYd5xrd0C"
      },
      "source": [
        "Dataset loading\n",
        "===============\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The dataset containing images of *ants* and *bees* can be downloaded\n",
        "[here](https://download.pytorch.org/tutorial/hymenoptera_data.zip) and\n",
        "should be extracted in the subfolder `../_data/hymenoptera_data`.\n",
        ":::\n",
        "\n",
        "This is a very small dataset (roughly 250 images), too small for\n",
        "training from scratch a classical or quantum model, however it is enough\n",
        "when using *transfer learning* approach.\n",
        "\n",
        "The PyTorch packages `torchvision` and `torch.utils.data` are used for\n",
        "loading the dataset and performing standard preliminary image\n",
        "operations: resize, center, crop, normalize, *etc.*\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 82,
      "metadata": {
        "id": "jQrNNVnUrd0C"
      },
      "outputs": [],
      "source": [
        "data_transforms = {\n",
        "    \"train\": transforms.Compose(\n",
        "        [\n",
        "            # transforms.RandomResizedCrop(224),     # uncomment for data augmentation\n",
        "            # transforms.RandomHorizontalFlip(),     # uncomment for data augmentation\n",
        "            transforms.Resize(232),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            # Normalize input channels using mean values and standard deviations of ImageNet.\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "    \"val\": transforms.Compose(\n",
        "        [\n",
        "            transforms.Resize(232),\n",
        "            transforms.CenterCrop(224),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "        ]\n",
        "    ),\n",
        "}\n",
        "\n",
        "data_dir = \"/content/drive/MyDrive/Colab Notebooks/data/Shuffle Split 10 of 17 Classes Big Brain Tumor MRI Images\"\n",
        "image_datasets = {\n",
        "    x if x == \"train\" else \"validation\": datasets.ImageFolder(\n",
        "        os.path.join(data_dir, x), data_transforms[x]\n",
        "    )\n",
        "    for x in [\"train\", \"val\"]\n",
        "}\n",
        "dataset_sizes = {x: len(image_datasets[x]) for x in [\"train\", \"validation\"]}\n",
        "class_names = image_datasets[\"train\"].classes\n",
        "\n",
        "# Initialize dataloader\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}\n",
        "\n",
        "# function to plot images\n",
        "def imshow(inp, title=None):\n",
        "    \"\"\"Display image from tensor.\"\"\"\n",
        "    inp = inp.numpy().transpose((1, 2, 0))\n",
        "    # Inverse of the initial normalization operation.\n",
        "    mean = np.array([0.485, 0.456, 0.406])\n",
        "    std = np.array([0.229, 0.224, 0.225])\n",
        "    inp = std * inp + mean\n",
        "    inp = np.clip(inp, 0, 1)\n",
        "    plt.imshow(inp)\n",
        "    if title is not None:\n",
        "        plt.title(title)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "piWk71nkrd0C"
      },
      "source": [
        "Let us show a batch of the test data, just to have an idea of the\n",
        "classification problem.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 83,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 107
        },
        "id": "u55iZYEOrd0D",
        "outputId": "dc55e578-9efb-4902-a0eb-ac0f13015ddf"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Get a batch of training data\n",
        "inputs, classes = next(iter(dataloaders[\"validation\"]))\n",
        "\n",
        "# Make a grid from batch\n",
        "out = torchvision.utils.make_grid(inputs)\n",
        "\n",
        "imshow(out, title=[class_names[x] for x in classes])\n",
        "\n",
        "dataloaders = {\n",
        "    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)\n",
        "    for x in [\"train\", \"validation\"]\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ds09ighxrd0D"
      },
      "source": [
        "Variational quantum circuit\n",
        "===========================\n",
        "\n",
        "We first define some quantum layers that will compose the quantum\n",
        "circuit.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 84,
      "metadata": {
        "id": "OIbOa-KBrd0D"
      },
      "outputs": [],
      "source": [
        "def H_layer(nqubits):\n",
        "    \"\"\"Layer of single-qubit Hadamard gates.\n",
        "    \"\"\"\n",
        "    for idx in range(nqubits):\n",
        "        qml.Hadamard(wires=idx)\n",
        "\n",
        "\n",
        "def RY_layer(w):\n",
        "    \"\"\"Layer of parametrized qubit rotations around the y axis.\n",
        "    \"\"\"\n",
        "    for idx, element in enumerate(w):\n",
        "        qml.RY(element, wires=idx)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ue__y4sQrd0D"
      },
      "source": [
        "Now we define the quantum circuit through the PennyLane\n",
        "[qnode]{.title-ref} decorator .\n",
        "\n",
        "The structure is that of a typical variational quantum circuit:\n",
        "\n",
        "-   **Embedding layer:** All qubits are first initialized in a balanced\n",
        "    superposition of *up* and *down* states, then they are rotated\n",
        "    according to the input parameters (local embedding).\n",
        "-   **Variational layers:** A sequence of trainable rotation layers and\n",
        "    constant entangling layers is applied.\n",
        "-   **Measurement layer:** For each qubit, the local expectation value\n",
        "    of the $Z$ operator is measured. This produces a classical output\n",
        "    vector, suitable for additional post-processing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 85,
      "metadata": {
        "id": "C-O_cK6jrd0D"
      },
      "outputs": [],
      "source": [
        "@qml.qnode(dev, interface=\"torch\")\n",
        "def quantum_net(q_input_features, q_weights_flat):\n",
        "    \"\"\"\n",
        "    The quantum circuit.\n",
        "    \"\"\"\n",
        "\n",
        "    # Reshape weights\n",
        "    q_weights = q_weights_flat.reshape(q_depth, n_qubits)\n",
        "\n",
        "    # Start from state |+> , unbiased w.r.t. |0> and |1>\n",
        "    H_layer(n_qubits)\n",
        "\n",
        "    # Embed features in the quantum node\n",
        "    RY_layer(q_input_features)\n",
        "\n",
        "    # Expectation values in the Z basis\n",
        "    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]\n",
        "    return tuple(exp_vals)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7FNHREgVrd0D"
      },
      "source": [
        "Dressed quantum circuit\n",
        "=======================\n",
        "\n",
        "We can now define a custom `torch.nn.Module` representing a *dressed*\n",
        "quantum circuit.\n",
        "\n",
        "This is a concatenation of:\n",
        "\n",
        "-   A classical pre-processing layer (`nn.Linear`).\n",
        "-   A classical activation function (`torch.tanh`).\n",
        "-   A constant `np.pi/2.0` scaling.\n",
        "-   The previously defined quantum circuit (`quantum_net`).\n",
        "-   A classical post-processing layer (`nn.Linear`).\n",
        "\n",
        "The input of the module is a batch of vectors with 512 real parameters\n",
        "(features) and the output is a batch of vectors with two real outputs\n",
        "(associated with the two classes of images: *ants* and *bees*).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 86,
      "metadata": {
        "id": "C96kWCjWrd0D"
      },
      "outputs": [],
      "source": [
        "class DressedQuantumNet(nn.Module):\n",
        "    \"\"\"\n",
        "    Torch module implementing the *dressed* quantum net.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        \"\"\"\n",
        "        Definition of the *dressed* layout.\n",
        "        \"\"\"\n",
        "\n",
        "        super().__init__()\n",
        "        self.pre_net = nn.Linear(2048, n_qubits)\n",
        "        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))\n",
        "        self.post_net = nn.Linear(n_qubits, 10)\n",
        "\n",
        "    def forward(self, input_features):\n",
        "        \"\"\"\n",
        "        Defining how tensors are supposed to move through the *dressed* quantum\n",
        "        net.\n",
        "        \"\"\"\n",
        "\n",
        "        # obtain the input features for the quantum circuit\n",
        "        # by reducing the feature dimension from 512 to 4\n",
        "        pre_out = self.pre_net(input_features)\n",
        "        q_in = torch.tanh(pre_out) * np.pi / 2.0\n",
        "\n",
        "        # Apply the quantum circuit to each element of the batch and append to q_out\n",
        "        q_out = torch.Tensor(0, n_qubits)\n",
        "        q_out = q_out.to(device)\n",
        "        for elem in q_in:\n",
        "            q_out_elem = torch.hstack(quantum_net(elem, self.q_params)).float().unsqueeze(0)\n",
        "            q_out = torch.cat((q_out, q_out_elem))\n",
        "\n",
        "        # return the two-dimensional prediction from the postprocessing layer\n",
        "        return self.post_net(q_out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7hQlpDQdrd0D"
      },
      "source": [
        "Hybrid classical-quantum model\n",
        "==============================\n",
        "\n",
        "We are finally ready to build our full hybrid classical-quantum network.\n",
        "We follow the *transfer learning* approach:\n",
        "\n",
        "1.  First load the classical pre-trained network *ResNet18* from the\n",
        "    `torchvision.models` zoo.\n",
        "2.  Freeze all the weights since they should not be trained.\n",
        "3.  Replace the last fully connected layer with our trainable dressed\n",
        "    quantum circuit (`DressedQuantumNet`).\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "The *ResNet18* model is automatically downloaded by PyTorch and it may\n",
        "take several minutes (only the first time).\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 87,
      "metadata": {
        "id": "MAh4FqBYrd0D",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "5c092ed9-c1c1-480e-d120-41ff791b3f3a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n"
          ]
        }
      ],
      "source": [
        "model_hybrid = torchvision.models.resnet50(pretrained=True)\n",
        "\n",
        "for param in model_hybrid.parameters():\n",
        "    param.requires_grad = False\n",
        "\n",
        "\n",
        "# Notice that model_hybrid.fc is the last layer of ResNet18\n",
        "model_hybrid.fc = DressedQuantumNet()\n",
        "\n",
        "# Use CUDA or CPU according to the \"device\" object.\n",
        "model_hybrid = model_hybrid.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovX-Tkb0rd0E"
      },
      "source": [
        "Training and results\n",
        "====================\n",
        "\n",
        "Before training the network we need to specify the *loss* function.\n",
        "\n",
        "We use, as usual in classification problem, the *cross-entropy* which is\n",
        "directly available within `torch.nn`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 88,
      "metadata": {
        "id": "gkmFIK4Brd0E"
      },
      "outputs": [],
      "source": [
        "criterion = nn.CrossEntropyLoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qhFORuvard0E"
      },
      "source": [
        "We also initialize the *Adam optimizer* which is called at each training\n",
        "step in order to update the weights of the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 89,
      "metadata": {
        "id": "_K9M-VPMrd0E"
      },
      "outputs": [],
      "source": [
        "optimizer_hybrid = optim.Adam(model_hybrid.fc.parameters(), lr=step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IGG8pksyrd0E"
      },
      "source": [
        "We schedule to reduce the learning rate by a factor of\n",
        "`gamma_lr_scheduler` every 10 epochs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 90,
      "metadata": {
        "id": "nco3IrE6rd0E"
      },
      "outputs": [],
      "source": [
        "exp_lr_scheduler = lr_scheduler.StepLR(\n",
        "    optimizer_hybrid, step_size=1000, gamma=gamma_lr_scheduler\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yJvpPNCRrd0E"
      },
      "source": [
        "What follows is a training function that will be called later. This\n",
        "function should return a trained model that can be used to make\n",
        "predictions (classifications).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 91,
      "metadata": {
        "id": "8uO5BrOPrd0E"
      },
      "outputs": [],
      "source": [
        "def train_model(model, criterion, optimizer, scheduler, num_epochs):\n",
        "    since = time.time()\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "    best_loss = 10000.0  # Large arbitrary number\n",
        "    best_acc_train = 0.0\n",
        "    best_loss_train = 10000.0  # Large arbitrary number\n",
        "    print(\"Training started:\")\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in [\"train\", \"validation\"]:\n",
        "            if phase == \"train\":\n",
        "                # Set model to training mode\n",
        "                model.train()\n",
        "            else:\n",
        "                # Set model to evaluate mode\n",
        "                model.eval()\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "\n",
        "            # Iterate over data.\n",
        "            n_batches = dataset_sizes[phase] // batch_size\n",
        "            it = 0\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                since_batch = time.time()\n",
        "                batch_size_ = len(inputs)\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # Track/compute gradient and make an optimization step only when training\n",
        "                with torch.set_grad_enabled(phase == \"train\"):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "                    if phase == \"train\":\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # Print iteration results\n",
        "                running_loss += loss.item() * batch_size_\n",
        "                batch_corrects = torch.sum(preds == labels.data).item()\n",
        "                running_corrects += batch_corrects\n",
        "                print(\n",
        "                    \"Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}\".format(\n",
        "                        phase,\n",
        "                        epoch + 1,\n",
        "                        num_epochs,\n",
        "                        it + 1,\n",
        "                        n_batches + 1,\n",
        "                        time.time() - since_batch,\n",
        "                    ),\n",
        "                    end=\"\\r\",\n",
        "                    flush=True,\n",
        "                )\n",
        "                it += 1\n",
        "\n",
        "            # Print epoch results\n",
        "            epoch_loss = running_loss / dataset_sizes[phase]\n",
        "            epoch_acc = running_corrects / dataset_sizes[phase]\n",
        "            print(\n",
        "                \"Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        \".format(\n",
        "                    \"train\" if phase == \"train\" else \"validation  \",\n",
        "                    epoch + 1,\n",
        "                    num_epochs,\n",
        "                    epoch_loss,\n",
        "                    epoch_acc,\n",
        "                )\n",
        "            )\n",
        "\n",
        "            # Check if this is the best model wrt previous epochs\n",
        "            if phase == \"validation\" and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "            if phase == \"validation\" and epoch_loss < best_loss:\n",
        "                best_loss = epoch_loss\n",
        "            if phase == \"train\" and epoch_acc > best_acc_train:\n",
        "                best_acc_train = epoch_acc\n",
        "            if phase == \"train\" and epoch_loss < best_loss_train:\n",
        "                best_loss_train = epoch_loss\n",
        "      \n",
        "            # Update learning rate\n",
        "            if phase == \"train\":\n",
        "                scheduler.step()\n",
        "\n",
        "    # Print final results\n",
        "    model.load_state_dict(best_model_wts)\n",
        "    time_elapsed = time.time() - since\n",
        "    print(\n",
        "        \"Training completed in {:.0f}m {:.0f}s\".format(time_elapsed // 60, time_elapsed % 60)\n",
        "    )\n",
        "    print(\"Best test loss: {:.4f} | Best test accuracy: {:.4f}\".format(best_loss, best_acc))\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#!pip install -U scikit-learn\n",
        "import numpy as np\n",
        "import torch\n",
        "import copy\n",
        "import time\n",
        "import sklearn\n",
        "from sklearn.metrics import confusion_matrix\n",
        "\n",
        "def train_model(model, criterion, optimizer, scheduler, num_epochs):\n",
        "    since = time.time()\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "    best_loss = 10000.0  # Large arbitrary number\n",
        "    best_acc_train = 0.0\n",
        "    best_loss_train = 10000.0  # Large arbitrary number\n",
        "    print(\"Training started:\")\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in [\"train\", \"validation\"]:\n",
        "            if phase == \"train\":\n",
        "                # Set model to training mode\n",
        "                model.train()\n",
        "            else:\n",
        "                # Set model to evaluate mode\n",
        "                model.eval()\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "            all_labels = []\n",
        "            all_preds = []\n",
        "\n",
        "            # Iterate over data.\n",
        "            n_batches = dataset_sizes[phase] // batch_size\n",
        "            it = 0\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                since_batch = time.time()\n",
        "                batch_size_ = len(inputs)\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # Track/compute gradient and make an optimization step only when training\n",
        "                with torch.set_grad_enabled(phase == \"train\"):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "                    if phase == \"train\":\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # Print iteration results\n",
        "                running_loss += loss.item() * batch_size_\n",
        "                batch_corrects = torch.sum(preds == labels.data).item()\n",
        "                running_corrects += batch_corrects\n",
        "\n",
        "                all_labels.extend(labels.cpu().numpy())\n",
        "                all_preds.extend(preds.cpu().numpy())\n",
        "\n",
        "                print(\n",
        "                    \"Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}\".format(\n",
        "                        phase,\n",
        "                        epoch + 1,\n",
        "                        num_epochs,\n",
        "                        it + 1,\n",
        "                        n_batches + 1,\n",
        "                        time.time() - since_batch,\n",
        "                    ),\n",
        "                    end=\"\\r\",\n",
        "                    flush=True,\n",
        "                )\n",
        "                it += 1\n",
        "\n",
        "            # Print epoch results\n",
        "            epoch_loss = running_loss / dataset_sizes[phase]\n",
        "            epoch_acc = running_corrects / dataset_sizes[phase]\n",
        "            print(\n",
        "                \"Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        \".format(\n",
        "                    \"train\" if phase == \"train\" else \"validation  \",\n",
        "                    epoch + 1,\n",
        "                    num_epochs,\n",
        "                    epoch_loss,\n",
        "                    epoch_acc,\n",
        "                )\n",
        "            )\n",
        "\n",
        "            # Check if this is the best model wrt previous epochs\n",
        "            if phase == \"validation\" and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "            if phase == \"validation\" and epoch_loss < best_loss:\n",
        "                best_loss = epoch_loss\n",
        "            if phase == \"train\" and epoch_acc > best_acc_train:\n",
        "                best_acc_train = epoch_acc\n",
        "            if phase == \"train\" and epoch_loss < best_loss_train:\n",
        "                best_loss_train = epoch_loss\n",
        "\n",
        "            # Update learning rate\n",
        "            if phase == \"train\":\n",
        "                scheduler.step()\n",
        "\n",
        "            # Calculate and print confusion matrix\n",
        "            if phase == \"validation\":\n",
        "                cm = confusion_matrix(all_labels, all_preds)\n",
        "                print(\"Confusion Matrix:\")\n",
        "                print(cm)\n",
        "\n",
        "    # Print final results"
      ],
      "metadata": {
        "id": "tMGx77J_zSLE"
      },
      "execution_count": 92,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CmAl9fIQrd0E"
      },
      "source": [
        "We are ready to perform the actual training process.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 93,
      "metadata": {
        "id": "rAQBdDA_rd0E",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 3933
        },
        "outputId": "7b3e4f1f-4182-47ac-8800-341e2dd94875"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training started:\n",
            "Phase: train Epoch: 1/30 Loss: 1.9051 Acc: 0.2975        \n",
            "Phase: validation   Epoch: 1/30 Loss: 1.6243 Acc: 0.4897        \n",
            "Confusion Matrix:\n",
            "[[ 55   7   3  58  32   2   0   0   0   0]\n",
            " [  5  55   2   6  99   0   0   0   0   0]\n",
            " [  0   0  61   0   1  58   0   3   0   0]\n",
            " [ 22   0   2  93  12   0   0   0   0   0]\n",
            " [ 14   4   0   5 207   2   0   1   0   0]\n",
            " [  1   0  19   1   3  95   0   4   0   0]\n",
            " [  6   0   3  50  22   0  18   1   0   1]\n",
            " [  0   0   7   0   0  55   0  47   0   0]\n",
            " [  3   0   0   2  90   0   0   0   0   0]\n",
            " [  1   0   5   2  53   0   0   1   0  10]]\n",
            "Phase: train Epoch: 2/30 Loss: 1.5724 Acc: 0.4823        \n",
            "Phase: validation   Epoch: 2/30 Loss: 1.4402 Acc: 0.5172        \n",
            "Confusion Matrix:\n",
            "[[ 81   1   0  51  23   0   1   0   0   0]\n",
            " [  8  39   0   2 118   0   0   0   0   0]\n",
            " [  0   0  96   0   1  11   0  15   0   0]\n",
            " [ 20   0   2  96  11   0   0   0   0   0]\n",
            " [  7   0   2  11 212   0   0   1   0   0]\n",
            " [  0   0  75   1   3  30   0  14   0   0]\n",
            " [ 12   0   1  37  18   0  31   1   0   1]\n",
            " [  0   0  21   0   1  11   0  76   0   0]\n",
            " [  5   0   0   1  88   0   0   0   1   0]\n",
            " [  2   0   3   0  51   0   0   1   0  15]]\n",
            "Phase: train Epoch: 3/30 Loss: 1.3715 Acc: 0.5842        \n",
            "Phase: validation   Epoch: 3/30 Loss: 1.2943 Acc: 0.5737        \n",
            "Confusion Matrix:\n",
            "[[ 96   1   0  25  26   0   9   0   0   0]\n",
            " [  7  51   0   0 109   0   0   0   0   0]\n",
            " [  0   0  81   0  10  10   1  21   0   0]\n",
            " [ 35   0   0  75  13   0   6   0   0   0]\n",
            " [  7   0   0   1 222   0   1   2   0   0]\n",
            " [  2   0  56   0   8  25   1  31   0   0]\n",
            " [  5   0   0  10  11   0  75   0   0   0]\n",
            " [  0   0   8   0   3   7   0  91   0   0]\n",
            " [  1   0   0   1  68   0   0   0  25   0]\n",
            " [  0   0   0   0  59   0   3   0   0  10]]\n",
            "Phase: train Epoch: 4/30 Loss: 1.2408 Acc: 0.6269        \n",
            "Phase: validation   Epoch: 4/30 Loss: 1.1349 Acc: 0.6960        \n",
            "Confusion Matrix:\n",
            "[[ 75   2   1  47  10   0  13   0   6   3]\n",
            " [  9  89   0   2  60   0   2   0   4   1]\n",
            " [  0   0  78   0   0  30   1  14   0   0]\n",
            " [ 16   0   2  97   1   0  11   0   2   0]\n",
            " [  8   1   0   7 193   1   2   1   6  14]\n",
            " [  0   0  20   1   1  86   1  13   1   0]\n",
            " [  4   0   0   9   2   0  85   0   0   1]\n",
            " [  0   0   5   0   0  22   0  82   0   0]\n",
            " [  0   1   0   2   9   0   0   0  82   1]\n",
            " [  0   0   2   1  20   0   5   0   0  44]]\n",
            "Phase: train Epoch: 5/30 Loss: 1.1178 Acc: 0.6779        \n",
            "Phase: validation   Epoch: 5/30 Loss: 1.0509 Acc: 0.7089        \n",
            "Confusion Matrix:\n",
            "[[123   9   2  18   0   0   4   0   0   1]\n",
            " [ 12 139   0   1  14   0   0   0   1   0]\n",
            " [  0   0  87   0   0  14   1  21   0   0]\n",
            " [ 37   2   1  84   1   0   4   0   0   0]\n",
            " [ 10  48   1  14 144   0   2   1   0  13]\n",
            " [  0   0  38   3   0  59   0  22   0   1]\n",
            " [  9   5   0  10   0   0  77   0   0   0]\n",
            " [  0   0   8   0   0   9   0  92   0   0]\n",
            " [  3   5   0   1   9   0   0   0  77   0]\n",
            " [  1   9   3   2   8   0   3   0   0  46]]\n",
            "Phase: train Epoch: 6/30 Loss: 1.0091 Acc: 0.7106        \n",
            "Phase: validation   Epoch: 6/30 Loss: 0.9293 Acc: 0.7410        \n",
            "Confusion Matrix:\n",
            "[[110  11   1  16   3   0  13   0   0   3]\n",
            " [  8 131   0   0  24   0   0   1   3   0]\n",
            " [  0   0  82   0   0  21   1  19   0   0]\n",
            " [ 30   0   2  83   3   0  10   0   1   0]\n",
            " [  8  26   0   8 174   1   1   1   3  11]\n",
            " [  0   0  25   0   0  74   0  23   0   1]\n",
            " [  8   1   0   4   2   0  85   0   0   1]\n",
            " [  0   0   4   0   0   9   0  96   0   0]\n",
            " [  0   4   0   2   3   0   0   0  86   0]\n",
            " [  0   7   4   1   9   0   2   0   0  49]]\n",
            "Phase: train Epoch: 7/30 Loss: 0.9341 Acc: 0.7411        \n",
            "Phase: validation   Epoch: 7/30 Loss: 1.0064 Acc: 0.6944        \n",
            "Confusion Matrix:\n",
            "[[126   9   0   9  11   0   2   0   0   0]\n",
            " [  7 125   0   0  34   0   0   0   1   0]\n",
            " [  5   6  82   1  10  11   1   7   0   0]\n",
            " [ 47   1   1  64  11   0   5   0   0   0]\n",
            " [  4  14   0   2 208   1   1   0   1   2]\n",
            " [  8   6  42   2  10  48   0   7   0   0]\n",
            " [  9   1   0   5   5   0  80   0   0   1]\n",
            " [  3   3  14   0   3  10   2  74   0   0]\n",
            " [  0   3   0   1  14   0   0   0  77   0]\n",
            " [  2   7   0   2  36   0   0   0   0  25]]\n",
            "Phase: train Epoch: 8/30 Loss: 0.8810 Acc: 0.7470        \n",
            "Phase: validation   Epoch: 8/30 Loss: 0.8609 Acc: 0.7471        \n",
            "Confusion Matrix:\n",
            "[[101   5   1  24   9   0  16   0   0   1]\n",
            " [  5 116   0   1  39   0   1   1   4   0]\n",
            " [  0   0  88   0   1   6   1  27   0   0]\n",
            " [ 11   1   0  99   6   0  11   0   1   0]\n",
            " [  2   5   1   9 199   0   2   1   6   8]\n",
            " [  0   0  37   0   3  49   1  33   0   0]\n",
            " [  4   0   0   5   2   0  90   0   0   0]\n",
            " [  0   0   6   0   0   0   0 103   0   0]\n",
            " [  0   2   0   2   3   0   0   0  88   0]\n",
            " [  0   2   2   1  19   0   3   0   0  45]]\n",
            "Phase: train Epoch: 9/30 Loss: 0.8138 Acc: 0.7666        \n",
            "Phase: validation   Epoch: 9/30 Loss: 0.7972 Acc: 0.7678        \n",
            "Confusion Matrix:\n",
            "[[114   7   1  19   4   0  12   0   0   0]\n",
            " [  8 125   0   1  30   1   0   1   1   0]\n",
            " [  0   0  76   0   0  22   1  24   0   0]\n",
            " [ 16   0   0  97   4   0  11   0   1   0]\n",
            " [  3  11   0   8 194   1   3   1   7   5]\n",
            " [  0   0  23   0   0  69   1  29   0   1]\n",
            " [  3   0   0   3   1   0  94   0   0   0]\n",
            " [  0   0   3   0   0   4   0 102   0   0]\n",
            " [  2   1   0   1   4   0   0   0  87   0]\n",
            " [  0   5   3   1  11   0   4   0   1  47]]\n",
            "Phase: train Epoch: 10/30 Loss: 0.7510 Acc: 0.7775        \n",
            "Phase: validation   Epoch: 10/30 Loss: 0.7569 Acc: 0.7662        \n",
            "Confusion Matrix:\n",
            "[[ 98   7   1  26   2   0  20   0   0   3]\n",
            " [  8 132   0   0  20   0   1   0   2   4]\n",
            " [  0   0  85   0   0  30   1   7   0   0]\n",
            " [  9   0   1 103   1   0  15   0   0   0]\n",
            " [  2  16   0  12 180   2   3   0   4  14]\n",
            " [  0   0  20   0   0  90   1  11   0   1]\n",
            " [  3   2   0   3   0   0  93   0   0   0]\n",
            " [  0   0   6   0   0  24   0  79   0   0]\n",
            " [  0   4   0   2   3   0   0   0  86   0]\n",
            " [  0   3   3   0   4   0   5   0   0  57]]\n",
            "Phase: train Epoch: 11/30 Loss: 0.7298 Acc: 0.7862        \n",
            "Phase: validation   Epoch: 11/30 Loss: 0.7479 Acc: 0.7578        \n",
            "Confusion Matrix:\n",
            "[[109   2   0  18   6   0  11   0   8   3]\n",
            " [  4  96   0   0  57   1   1   1   7   0]\n",
            " [  0   0  72   0   1  38   0  11   0   1]\n",
            " [ 11   0   0  93   6   0   9   0   5   5]\n",
            " [  3   2   0   3 207   2   1   0   6   9]\n",
            " [  0   0  14   1   2  96   0   9   1   0]\n",
            " [  4   0   0   5   2   0  89   0   0   1]\n",
            " [  0   0   5   0   0  13   0  91   0   0]\n",
            " [  1   1   0   1   5   0   0   0  87   0]\n",
            " [  0   1   0   1  15   0   2   0   1  52]]\n",
            "Phase: train Epoch: 12/30 Loss: 0.7349 Acc: 0.7712        \n",
            "Phase: validation   Epoch: 12/30 Loss: 0.6927 Acc: 0.7754        \n",
            "Confusion Matrix:\n",
            "[[115   8   0  17   0   0  13   0   1   3]\n",
            " [  9 128   0   1  26   0   0   0   1   2]\n",
            " [  1   1  88   0   2  27   1   3   0   0]\n",
            " [ 12   1   0  98   7   0   9   0   0   2]\n",
            " [  2  11   0  10 196   2   1   0   0  11]\n",
            " [  1   2  28   3   2  81   0   5   0   1]\n",
            " [  4   2   0   3   0   0  92   0   0   0]\n",
            " [  1   2  11   0   0  21   1  72   0   1]\n",
            " [  0   3   0   1   5   0   0   0  86   0]\n",
            " [  0   4   0   0   6   0   3   0   0  59]]\n",
            "Phase: train Epoch: 13/30 Loss: 0.6608 Acc: 0.7948        \n",
            "Phase: validation   Epoch: 13/30 Loss: 0.6373 Acc: 0.8021        \n",
            "Confusion Matrix:\n",
            "[[122  11   1  11   2   0   6   0   1   3]\n",
            " [  3 138   0   0  22   1   0   0   2   1]\n",
            " [  0   0  89   0   2  25   0   6   0   1]\n",
            " [ 17   3   1  96   6   0   4   0   0   2]\n",
            " [  4  15   0   3 199   2   0   0   0  10]\n",
            " [  0   1  17   0   1  94   0  10   0   0]\n",
            " [ 10   2   0   2   0   0  87   0   0   0]\n",
            " [  0   0   9   0   0  14   0  86   0   0]\n",
            " [  0   5   0   1   3   0   0   0  86   0]\n",
            " [  3   9   1   0   6   0   0   0   0  53]]\n",
            "Phase: train Epoch: 14/30 Loss: 0.6274 Acc: 0.8144        \n",
            "Phase: validation   Epoch: 14/30 Loss: 0.6640 Acc: 0.7830        \n",
            "Confusion Matrix:\n",
            "[[114   6   1  23   2   0   7   0   1   3]\n",
            " [  3 125   0   1  35   0   0   0   2   1]\n",
            " [  0   0  86   1   2  24   1   9   0   0]\n",
            " [  9   1   0 108   5   0   4   0   1   1]\n",
            " [  2   7   0  10 199   2   1   0   1  11]\n",
            " [  0   0  31   4   2  78   0   8   0   0]\n",
            " [  3   0   0   7   1   0  90   0   0   0]\n",
            " [  0   0  12   0   0  10   0  87   0   0]\n",
            " [  0   2   0   2   3   0   0   0  87   1]\n",
            " [  3   2   1   3  11   0   1   0   0  51]]\n",
            "Phase: train Epoch: 15/30 Loss: 0.6185 Acc: 0.8116        \n",
            "Phase: validation   Epoch: 15/30 Loss: 0.7286 Acc: 0.7296        \n",
            "Confusion Matrix:\n",
            "[[117   4   1  16   0   0  18   0   0   1]\n",
            " [ 13 118   0   1  25   1   1   1   2   5]\n",
            " [  0   0 106   0   0   2   2  13   0   0]\n",
            " [ 14   0   0  78   2   0  32   0   0   3]\n",
            " [  1   8   2   9 180   0   8   0   6  19]\n",
            " [  0   0  77   2   1  23   1  18   0   1]\n",
            " [  0   0   1   1   1   0  98   0   0   0]\n",
            " [  0   0  22   0   0   1   0  86   0   0]\n",
            " [  1   1   0   1   3   0   1   0  88   0]\n",
            " [  0   0   2   0   4   0   5   0   0  61]]\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-93-bbd0ce1e21df>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m model_hybrid = train_model(\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mmodel_hybrid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_hybrid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexp_lr_scheduler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m )\n",
            "\u001b[0;32m<ipython-input-92-ef1a1cf14921>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, criterion, optimizer, scheduler, num_epochs)\u001b[0m\n\u001b[1;32m     43\u001b[0m                 \u001b[0;31m# Track/compute gradient and make an optimization step only when training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"train\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m                     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     46\u001b[0m                     \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m                     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1499\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1500\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1502\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1503\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torchvision/models/resnet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    283\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    284\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 285\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    286\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    287\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torchvision/models/resnet.py\u001b[0m in \u001b[0;36m_forward_impl\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    278\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mavgpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 280\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    282\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1499\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1500\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1502\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1503\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-86-ef8634c23b32>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_features)\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0mq_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0melem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mq_in\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m             \u001b[0mq_out_elem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquantum_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0melem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mq_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     32\u001b[0m             \u001b[0mq_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq_out_elem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/qnode.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    865\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_kwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"mode\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    866\u001b[0m             \u001b[0;31m# pylint: disable=unexpected-keyword-arg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 867\u001b[0;31m             res = qml.execute(\n\u001b[0m\u001b[1;32m    868\u001b[0m                 \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtape\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    869\u001b[0m                 \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/interfaces/execution.py\u001b[0m in \u001b[0;36mexecute\u001b[0;34m(tapes, device, gradient_fn, interface, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)\u001b[0m\n\u001b[1;32m    405\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mgradient_fn\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"backprop\"\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0minterface\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    406\u001b[0m         return batch_fn(\n\u001b[0;32m--> 407\u001b[0;31m             qml.interfaces.cache_execute(\n\u001b[0m\u001b[1;32m    408\u001b[0m                 \u001b[0mbatch_execute\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tuple\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpand_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexpand_fn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    409\u001b[0m             )(tapes)\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/interfaces/execution.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(tapes, **kwargs)\u001b[0m\n\u001b[1;32m    202\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    203\u001b[0m             \u001b[0;31m# execute all unique tapes that do not exist in the cache\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 204\u001b[0;31m             \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecution_tapes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    206\u001b[0m         \u001b[0mfinal_res\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/interfaces/execution.py\u001b[0m in \u001b[0;36mfn\u001b[0;34m(tapes, **kwargs)\u001b[0m\n\u001b[1;32m    128\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtapes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mQuantumTape\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=function-redefined\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m             \u001b[0mtapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mexpand_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtape\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtapes\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0moriginal_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    131\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    132\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.10/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m     77\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/_qubit_device.py\u001b[0m in \u001b[0;36mbatch_execute\u001b[0;34m(self, circuits)\u001b[0m\n\u001b[1;32m    586\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    587\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 588\u001b[0;31m             \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcircuit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    589\u001b[0m             \u001b[0mresults\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    590\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit_torch.py\u001b[0m in \u001b[0;36mexecute\u001b[0;34m(self, circuit, **kwargs)\u001b[0m\n\u001b[1;32m    230\u001b[0m                     )\n\u001b[1;32m    231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 232\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcircuit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    233\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    234\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_asarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/_qubit_device.py\u001b[0m in \u001b[0;36mexecute\u001b[0;34m(self, circuit, **kwargs)\u001b[0m\n\u001b[1;32m    316\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    317\u001b[0m         \u001b[0;31m# apply all circuit operations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 318\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcircuit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moperations\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_diagonalizing_gates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcircuit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    319\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    320\u001b[0m         \u001b[0;31m# generate computational basis samples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, operations, rotations, **kwargs)\u001b[0m\n\u001b[1;32m    274\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply_parametrized_evolution\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    275\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 276\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply_operation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    277\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    278\u001b[0m         \u001b[0;31m# store the pre-rotated state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py\u001b[0m in \u001b[0;36m_apply_operation\u001b[0;34m(self, state, operation)\u001b[0m\n\u001b[1;32m    320\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwires\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    321\u001b[0m             \u001b[0;31m# Einsum is faster for small gates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 322\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply_unitary_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmatrix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwires\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    324\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply_unitary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmatrix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwires\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py\u001b[0m in \u001b[0;36m_apply_unitary_einsum\u001b[0;34m(self, state, mat, wires)\u001b[0m\n\u001b[1;32m    868\u001b[0m         )\n\u001b[1;32m    869\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 870\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_indices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    871\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    872\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply_diagonal_unitary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwires\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/functional.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m    376\u001b[0m         \u001b[0;31m# the path for contracting 0 or 1 time(s) is already optimized\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    377\u001b[0m         \u001b[0;31m# or the user has disabled using opt_einsum\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mequation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperands\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    379\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    380\u001b[0m     \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "model_hybrid = train_model(\n",
        "    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "seconds = time.time()\n",
        "print(\"Time in seconds since beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "id": "-qU71GfkJOnf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s4cEc4mird0E"
      },
      "source": [
        "Visualizing the model predictions\n",
        "=================================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ARvjv_lbrd0E"
      },
      "source": [
        "We first define a visualization function for a batch of test data.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IBXZTnzjrd0E"
      },
      "source": [
        "References\n",
        "==========\n",
        "\n",
        "\\[1\\] Andrea Mari, Thomas R. Bromley, Josh Izaac, Maria Schuld, and\n",
        "Nathan Killoran. *Transfer learning in hybrid classical-quantum neural\n",
        "networks*. arXiv:1912.08278 (2019).\n",
        "\n",
        "\\[2\\] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and\n",
        "Andrew Y Ng. *Self-taught learning: transfer learning from unlabeled\n",
        "data*. Proceedings of the 24th International Conference on Machine\n",
        "Learning\\*, 759--766 (2007).\n",
        "\n",
        "\\[3\\] Kaiming He, Xiangyu Zhang, Shaoqing ren and Jian Sun. *Deep\n",
        "residual learning for image recognition*. Proceedings of the IEEE\n",
        "Conference on Computer Vision and Pattern Recognition, 770-778 (2016).\n",
        "\n",
        "\\[4\\] Ville Bergholm, Josh Izaac, Maria Schuld, Christian Gogolin,\n",
        "Carsten Blank, Keri McKiernan, and Nathan Killoran. *PennyLane:\n",
        "Automatic differentiation of hybrid quantum-classical computations*.\n",
        "arXiv:1811.04968 (2018).\n",
        "\n",
        "About the author\n",
        "================\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.16"
    },
    "colab": {
      "provenance": [],
      "gpuType": "V100"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}