[404218]: / Code / All PennyLane QML Demos / 22 Quantum GANs 0.64 Loss kkawchak.ipynb


{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "Adv9qCFis2V1",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "outputId": "6031da25-97f1-4790-f416-e3e2d3225184"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since beginning of run: 1693264730.4222255\n",
            "Mon Aug 28 23:18:50 2023\n"
          ]
        }
      ],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "%matplotlib inline\n",
        "# !pip install pennylane\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')\n",
        "import time\n",
        "seconds = time.time()\n",
        "print(\"Time in seconds since beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UK7teNKZs2V2"
      },
      "source": [
        "Quantum GANs {#quantum_gans}\n",
        "============\n",
        "\n",
        "::: {.meta}\n",
        ":property=\\\"og:description\\\": Explore quantum GANs to generate hand-written digits of zero\n",
        ":property=\\\"og:image\\\": <https://pennylane.ai/qml/_images/patch.jpeg>\n",
        ":::\n",
        "\n",
        "::: {.related}\n",
        "tutorial\\_QGAN Quantum generative adversarial networks with Cirq +\n",
        "TensorFlow\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xjulm5j1s2V3"
      },
      "source": [
        "*Author: James Ellis --- Posted: 01 February 2022. Last updated: 27\n",
        "January 2022.*\n",
        "\n",
        "In this tutorial, we will explore quantum GANs to generate hand-written\n",
        "digits of zero. We will first cover the theory of the classical case,\n",
        "then extend to a quantum method recently proposed in the literature. If\n",
        "you have no experience with GANs, particularly in PyTorch, you might\n",
        "find [PyTorch\\'s\n",
        "tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)\n",
        "useful since it serves as the foundation for what is to follow.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lEfsJh_fs2V3"
      },
      "source": [
        "Generative Adversarial Networks (GANs)\n",
        "======================================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EByLMGeNs2V3"
      },
      "source": [
        "The goal of generative adversarial networks (GANs) is to generate data\n",
        "that resembles the original data used in training. To achieve this, we\n",
        "train two neural networks simultaneously: a generator and a\n",
        "discriminator. The job of the generator is to create fake data which\n",
        "imitates the real training dataset. On the other hand, the discriminator\n",
        "acts like a detective trying to discern real from fake data. During the\n",
        "training process, the two players improve iteratively by competing against\n",
        "each other. By the end, the generator should ideally produce new data that\n",
        "closely resembles the training dataset.\n",
        "\n",
        "Specifically, the training dataset represents samples drawn from some\n",
        "unknown data distribution $P_{data}$, and the generator has the job of\n",
        "trying to capture this distribution. The generator, $G$, starts from\n",
        "some initial latent distribution, $P_z$, and maps it to $P_g = G(P_z)$.\n",
        "The best solution would be for $P_g = P_{data}$. However, this point is\n",
        "rarely achieved in practice except in the simplest tasks.\n",
        "\n",
        "Both the discriminator, $D$, and the generator, $G$, play a two-player\n",
        "minimax game. The discriminator tries to maximise the probability of\n",
        "correctly distinguishing real from fake data, while the generator tries to\n",
        "minimise that same probability. The value function for the game is\n",
        "\n",
        "$$\\begin{aligned}\n",
        "\\min_G \\max_D V(D,G) &= \\mathbb{E}_{\\boldsymbol{x}\\sim P_{data}}[\\log D(\\boldsymbol{x})] \\\\\n",
        "    & ~~ + \\mathbb{E}_{\\boldsymbol{z}\\sim P_{z}}[\\log(1 - D(G(\\boldsymbol{z})))]\n",
        "\\end{aligned}$$\n",
        "\n",
        "-   $\\boldsymbol{x}$: real data sample\n",
        "-   $\\boldsymbol{z}$: latent vector\n",
        "-   $D(\\boldsymbol{x})$: probability of the discriminator classifying\n",
        "    real data as real\n",
        "-   $G(\\boldsymbol{z})$: fake data\n",
        "-   $D(G(\\boldsymbol{z}))$: probability of the discriminator classifying\n",
        "    fake data as real\n",
        "\n",
        "In practice, the two networks are trained iteratively, each with a\n",
        "separate loss function to be minimised,\n",
        "\n",
        "$$L_D = -[y \\cdot \\log(D(x)) + (1-y)\\cdot \\log(1-D(G(z)))]$$\n",
        "\n",
        "$$L_G = [(1-y) \\cdot \\log(1-D(G(z)))]$$\n",
        "\n",
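        "where $y$ is a binary label for real ($y=1$) or fake ($y=0$) data. In\n",
        "practice, generator training has been found to be more stable when the\n",
        "generator maximises $\\log(D(G(z)))$ instead of minimising\n",
        "$\\log(1-D(G(z)))$: early in training, when the discriminator easily\n",
        "rejects fake samples, $\\log(1-D(G(z)))$ saturates and provides weak\n",
        "gradients. Hence, the generator loss function to be minimised becomes,\n",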
        "\n",
        "$$L_G = -[(1-y) \\cdot \\log(D(G(z)))]$$\n"
      ]
    },
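    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two generator objectives concrete, the following pure-Python\n",
        "sketch writes $L_D$ and both forms of $L_G$ as binary cross entropy on single\n",
        "samples. It is our own illustration, not part of the original demo, and all\n",
        "helper names are hypothetical. It then compares their gradients with respect\n",
        "to the discriminator's pre-sigmoid output $a$, where $D(G(z)) = \\sigma(a)$.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def bce(p, y):\n",
        "    # Binary cross entropy for one prediction p in (0, 1) and label y in {0, 1}\n",
        "    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
        "\n",
        "\n",
        "def discriminator_loss(d_real, d_fake):\n",
        "    # L_D for one real sample (y = 1) plus one fake sample (y = 0)\n",
        "    return bce(d_real, 1.0) + bce(d_fake, 0.0)\n",
        "\n",
        "\n",
        "def generator_loss_saturating(d_fake):\n",
        "    # Original minimax objective: minimise log(1 - D(G(z)))\n",
        "    return math.log(1 - d_fake)\n",
        "\n",
        "\n",
        "def generator_loss_non_saturating(d_fake):\n",
        "    # Non-saturating objective: minimise -log D(G(z)), i.e. BCE against label 1\n",
        "    return bce(d_fake, 1.0)\n",
        "\n",
        "\n",
        "def sigmoid(a):\n",
        "    return 1 / (1 + math.exp(-a))\n",
        "\n",
        "\n",
        "# Gradients with respect to the logit a, where D(G(z)) = sigmoid(a):\n",
        "#   d/da  log(1 - sigmoid(a)) = -sigmoid(a)\n",
        "#   d/da -log(sigmoid(a))     = -(1 - sigmoid(a))\n",
        "def grad_saturating(a):\n",
        "    return -sigmoid(a)\n",
        "\n",
        "\n",
        "def grad_non_saturating(a):\n",
        "    return -(1 - sigmoid(a))\n",
        "\n",
        "\n",
        "a = -6.0  # the discriminator confidently labels the fake sample as fake\n",
        "print('D(G(z)) =', round(sigmoid(a), 4))\n",
        "print('saturating gradient:', round(grad_saturating(a), 4))\n",
        "print('non-saturating gradient:', round(grad_non_saturating(a), 4))\n",
        "```\n",
        "\n",
        "With $a = -6$ the discriminator confidently rejects the fake sample\n",
        "($D(G(z)) \\approx 0.0025$). The gradient of $\\log(1-D(G(z)))$ is then almost\n",
        "zero, while the gradient of $-\\log D(G(z))$ stays close to $-1$.\n"
      ]
    },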
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZP6T-7Is2V3"
      },
      "source": [
        "Quantum GANs: The Patch Method\n",
        "==============================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PonC35zos2V3"
      },
      "source": [
        "In this tutorial, we re-create one of the quantum GAN methods presented\n",
        "by Huang et al.: the patch method. This method uses several quantum\n",
        "generators, with each sub-generator, $G^{(i)}$, responsible for\n",
        "constructing a small patch of the final image. The final image is\n",
        "constructed by concatenating all of the patches together as shown below.\n",
        "\n",
        "![](../demonstrations/quantum_gans/patch.jpeg){.align-center\n",
        "width=\"90.0%\"}\n",
        "\n",
        "The main advantage of this method is that it is particularly suited to\n",
        "situations where the number of available qubits is limited. The same\n",
        "quantum device can be used for each sub-generator in an iterative\n",
        "fashion, or execution of the generators can be parallelised across\n",
        "multiple devices.\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "In this tutorial, parenthesised superscripts are used to denote\n",
        "individual objects as part of a collection.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FoPuIvwJs2V3"
      },
      "source": [
        "Module Imports\n",
        "==============\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "3HyrRTKws2V3"
      },
      "outputs": [],
      "source": [
        "# Library imports\n",
        "import math\n",
        "import random\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.gridspec as gridspec\n",
        "import pennylane as qml\n",
        "\n",
        "# Pytorch imports\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "# Set the random seed for reproducibility\n",
        "seed = 42\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "random.seed(seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PlFxROwFs2V3"
      },
      "source": [
        "Data\n",
        "====\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Xc1K-WFs2V4"
      },
      "source": [
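        "As mentioned in the introduction, we will generate handwritten zeros. For\n",
        "training data we use the zeros from a [small\n",
        "dataset](https://archive.ics.uci.edu/ml/datasets/optical+recognition+of+handwritten+digits)\n",
        "of handwritten digits, in which each image is 8x8 pixels with integer\n",
        "intensities from 0 to 16. First, we need to create a custom PyTorch dataset\n",
        "class for this data.\n"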
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "bS-UnbQ8s2V4"
      },
      "outputs": [],
      "source": [
        "class DigitsDataset(Dataset):\n",
        "    \"\"\"PyTorch dataset for the Optical Recognition of Handwritten Digits Data Set\"\"\"\n",
        "\n",
        "    def __init__(self, csv_file, label=0, transform=None):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            csv_file (string): Path to the comma-separated data file (e.g. optdigits.tra).\n",
        "            label (int, optional): Digit class to keep; all other rows are discarded.\n",
        "            transform (callable, optional): Optional transform to be applied\n",
        "                on a sample.\n",
        "        \"\"\"\n",
        "        self.csv_file = csv_file\n",
        "        self.label = label\n",
        "        self.transform = transform\n",
        "        self.df = self.filter_by_label(label)\n",
        "\n",
        "    def filter_by_label(self, label):\n",
        "        # The file has no header row: each row holds 64 pixel values (0-16)\n",
        "        # followed by the class label in the last column\n",
        "        df = pd.read_csv(self.csv_file, header=None)\n",
        "        df = df.loc[df.iloc[:, -1] == label]\n",
        "        return df\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.df)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        if torch.is_tensor(idx):\n",
        "            idx = idx.tolist()\n",
        "\n",
        "        # Rescale pixel values from [0, 16] to [0, 1] and reshape to 8x8\n",
        "        image = self.df.iloc[idx, :-1] / 16\n",
        "        image = np.array(image)\n",
        "        image = image.astype(np.float32).reshape(8, 8)\n",
        "\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "\n",
        "        # Return image and its class label\n",
        "        return image, self.label"
      ]
    },
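    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the row format that `DigitsDataset` expects without downloading the data,\n",
        "the following sketch builds a small hypothetical CSV in the optdigits layout:\n",
        "64 pixel counts from 0 to 16, then the class label. It keeps one digit and\n",
        "rescales the pixels to [0, 1]. It also shows why `header=None` matters. The\n",
        "optdigits files have no header row, so by default pandas would consume the\n",
        "first sample as column names. The helper names are illustrative only.\n",
        "\n",
        "```python\n",
        "import io\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def make_fake_optdigits(seed=0):\n",
        "    # Hypothetical stand-in for optdigits.tra: 64 pixel counts in [0, 16]\n",
        "    # per row, then the class label in the last column, no header row\n",
        "    rng = np.random.default_rng(seed)\n",
        "    labels = np.array([0, 3, 0, 7, 0, 1])\n",
        "    pixels = rng.integers(0, 17, size=(len(labels), 64))\n",
        "    buf = io.StringIO()\n",
        "    np.savetxt(buf, np.column_stack([pixels, labels]), fmt='%d', delimiter=',')\n",
        "    return buf.getvalue()\n",
        "\n",
        "\n",
        "def load_digit(csv_text, label=0):\n",
        "    # Mirrors DigitsDataset: filter by label, rescale to [0, 1], reshape to 8x8\n",
        "    df = pd.read_csv(io.StringIO(csv_text), header=None)\n",
        "    df = df.loc[df.iloc[:, -1] == label]\n",
        "    pixels = df.iloc[:, :-1].to_numpy(dtype=np.float32) / 16\n",
        "    return pixels.reshape(-1, 8, 8)\n",
        "\n",
        "\n",
        "text = make_fake_optdigits()\n",
        "zeros = load_digit(text, label=0)\n",
        "print(zeros.shape)  # (3, 8, 8)\n",
        "\n",
        "# Without header=None, pandas turns the first sample into column names\n",
        "print(len(pd.read_csv(io.StringIO(text))), len(pd.read_csv(io.StringIO(text), header=None)))\n",
        "```\n"
      ]
    },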
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "za7LyZnNs2V4"
      },
      "source": [
        "Next we define some variables and create the dataloader instance.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "FejShZNxs2V4"
      },
      "outputs": [],
      "source": [
        "image_size = 8  # Height / width of the square images\n",
        "batch_size = 1\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor()])\n",
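        "# Path assumes optdigits.tra was saved to Google Drive and the drive is mounted (see the first cell)\n",
        "dataset = DigitsDataset(csv_file=\"/content/drive/MyDrive/Colab Notebooks/data/optdigits.tra\", transform=transform)\n",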
        "dataloader = torch.utils.data.DataLoader(\n",
        "    dataset, batch_size=batch_size, shuffle=True, drop_last=True\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FEFii_G2s2V4"
      },
      "source": [
        "Let\\'s visualize some of the data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "id": "dzGFZQyRs2V4",
        "outputId": "f0e8aea0-63b3-4ede-ba57-1b5bebd60051"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 800x200 with 8 Axes>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "plt.figure(figsize=(8,2))\n",
        "\n",
        "for i in range(8):\n",
        "    image = dataset[i][0].reshape(image_size,image_size)\n",
        "    plt.subplot(1,8,i+1)\n",
        "    plt.axis('off')\n",
        "    plt.imshow(image.numpy(), cmap='gray')\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-NemlIWxs2V4"
      },
      "source": [
        "Implementing the Discriminator\n",
        "==============================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IOcvPprks2V4"
      },
      "source": [
        "For the discriminator, we use a fully connected neural network with two\n",
        "hidden layers. A single output is sufficient to represent the\n",
        "probability of an input being classified as real.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "hPJvAU4Ts2V4"
      },
      "outputs": [],
      "source": [
        "class Discriminator(nn.Module):\n",
        "    \"\"\"Fully connected classical discriminator\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        self.model = nn.Sequential(\n",
        "            # Inputs to first hidden layer (num_input_features -> 64)\n",
        "            nn.Linear(image_size * image_size, 64),\n",
        "            nn.ReLU(),\n",
        "            # First hidden layer (64 -> 16)\n",
        "            nn.Linear(64, 16),\n",
        "            nn.ReLU(),\n",
        "            # Second hidden layer (16 -> output)\n",
        "            nn.Linear(16, 1),\n",
        "            nn.Sigmoid(),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)"
      ]
    },
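    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a worked example of the discriminator's size, an `nn.Linear(n_in, n_out)`\n",
        "layer holds `n_in * n_out` weights plus `n_out` biases. The following sketch\n",
        "tallies the three layers defined above. In the notebook,\n",
        "`sum(p.numel() for p in Discriminator().parameters())` should report the same\n",
        "total.\n",
        "\n",
        "```python\n",
        "def linear_params(n_in, n_out):\n",
        "    # Weights plus biases of one fully connected layer\n",
        "    return n_in * n_out + n_out\n",
        "\n",
        "\n",
        "image_size = 8\n",
        "layer_shapes = [(image_size * image_size, 64), (64, 16), (16, 1)]\n",
        "total = sum(linear_params(n_in, n_out) for n_in, n_out in layer_shapes)\n",
        "print(total)  # 5217\n",
        "```\n"
      ]
    },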
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1s1-1Hcus2V4"
      },
      "source": [
        "Implementing the Generator\n",
        "==========================\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V1rbipYps2V4"
      },
      "source": [
        "Each sub-generator, $G^{(i)}$, shares the same circuit architecture as\n",
        "shown below. The overall quantum generator consists of $N_G$\n",
        "sub-generators, each consisting of $N$ qubits. The process from latent\n",
        "vector input to image output can be split into four distinct sections:\n",
        "state embedding, parameterisation, non-linear transformation, and\n",
        "post-processing. Each of the following sections refers to a single\n",
        "iteration of the training process to simplify the discussion.\n",
        "\n",
        "![](../demonstrations/quantum_gans/qcircuit.jpeg){.align-center\n",
        "width=\"90.0%\"}\n",
        "\n",
        "**1) State Embedding**\n",
        "\n",
        "A latent vector, $\\boldsymbol{z}\\in\\mathbb{R}^N$, is sampled from a\n",
        "uniform distribution in the interval $[0,\\pi/2)$. All sub-generators\n",
        "receive the same latent vector which is then embedded using RY gates.\n",
        "\n",
        "**2) Parameterised Layers**\n",
        "\n",
        "The parameterised layer consists of parameterised RY gates followed by\n",
        "controlled-Z (CZ) gates between neighbouring qubits. This layer is repeated $D$ times in total.\n",
        "\n",
        "**3) Non-Linear Transform**\n",
        "\n",
        "Quantum gates in the circuit model are unitary which, by definition,\n",
        "linearly transform the quantum state. A linear mapping between the\n",
        "latent and generator distributions would suffice for only the simplest\n",
        "generative tasks, hence we need non-linear transformations. We introduce\n",
        "this non-linearity with the help of ancillary qubits.\n",
        "\n",
        "For a given sub-generator, the pre-measurement quantum state is given\n",
        "by,\n",
        "\n",
        "$$|\\Psi(\\boldsymbol{z})\\rangle = U_{G}(\\theta)|\\boldsymbol{z}\\rangle$$\n",
        "\n",
        "where $U_{G}(\\theta)$ represents the overall unitary of the\n",
        "parameterised layers. Let us inspect the state when we take a partial\n",
        "measurement, $\\Pi$, and trace out the ancillary subsystem, $\\mathcal{A}$,\n",
        "\n",
        "$$\\rho(\\boldsymbol{z}) = \\frac{\\text{Tr}_{\\mathcal{A}}(\\Pi \\otimes \\mathbb{I} |\\Psi(\\boldsymbol{z})\\rangle \\langle \\Psi(\\boldsymbol{z})|) }{\\text{Tr}(\\Pi \\otimes \\mathbb{I} |\\Psi(\\boldsymbol{z})\\rangle \\langle \\Psi(\\boldsymbol{z})|)} = \\frac{\\text{Tr}_{\\mathcal{A}}(\\Pi \\otimes \\mathbb{I} |\\Psi(\\boldsymbol{z})\\rangle \\langle \\Psi(\\boldsymbol{z})|) }{\\langle \\Psi(\\boldsymbol{z})| \\Pi \\otimes \\mathbb{I} |\\Psi(\\boldsymbol{z})\\rangle}$$\n",
        "\n",
        "The post-measurement state, $\\rho(\\boldsymbol{z})$, is dependent on\n",
        "$\\boldsymbol{z}$ in both the numerator and denominator. This means the\n",
        "state has been non-linearly transformed! For this tutorial,\n",
        "$\\Pi = (|0\\rangle \\langle0|)^{\\otimes N_A}$, where $N_A$ is the number\n",
        "of ancillary qubits in the system.\n",
        "\n",
        "With the remaining data qubits, we measure the probability of\n",
        "$\\rho(\\boldsymbol{z})$ in each computational basis state, $P(j)$, to\n",
        "obtain the sub-generator output, $\\boldsymbol{g}^{(i)}$,\n",
        "\n",
        "$$\\boldsymbol{g}^{(i)} = [P(0), P(1), ... ,P(2^{N-N_A} - 1)]$$\n",
        "\n",
        "**4) Post Processing**\n",
        "\n",
        "Due to the normalisation constraint of the measurement, all elements in\n",
        "$\\boldsymbol{g}^{(i)}$ must sum to one. This is a problem if we are to\n",
        "use $\\boldsymbol{g}^{(i)}$ as the pixel intensity values for our patch.\n",
        "For example, imagine a hypothetical situation where a patch of full\n",
        "intensity pixels was the target. The best patch a sub-generator could\n",
        "produce would be a patch of pixels all at a magnitude of\n",
        "$\\frac{1}{2^{N-N_A}}$. To alleviate this constraint, we apply a\n",
        "post-processing technique to each patch,\n",
        "\n",
        "$$\\boldsymbol{\\tilde{x}^{(i)}} = \\frac{\\boldsymbol{g}^{(i)}}{\\max_{k}\\boldsymbol{g}_k^{(i)}}$$\n",
        "\n",
        "Therefore, the final image, $\\boldsymbol{\\tilde{x}}$, is given by\n",
        "\n",
        "$$\\boldsymbol{\\tilde{x}} = [\\boldsymbol{\\tilde{x}^{(1)}}, ... ,\\boldsymbol{\\tilde{x}^{(N_G)}}]$$\n"
      ]
    },
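    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The non-linear transform and post-processing steps can be checked numerically.\n",
        "The following NumPy sketch starts from a hypothetical 5-qubit measurement\n",
        "distribution and post-selects the single ancilla in $|0\\rangle$. Here the\n",
        "ancilla is assumed to be the most significant qubit, which matches PennyLane's\n",
        "wire ordering when it sits on wire 0. The sketch then renormalises by\n",
        "$\\langle \\Psi| \\Pi \\otimes \\mathbb{I} |\\Psi\\rangle$ to obtain\n",
        "$\\boldsymbol{g}^{(i)}$, and finally divides by the largest entry.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "N, N_A = 5, 1  # total and ancillary qubits, as in this tutorial\n",
        "\n",
        "\n",
        "def postselect_and_rescale(probs, n=N, n_a=N_A):\n",
        "    # probs: probabilities of all 2**n basis states, ancillas as the leading bits\n",
        "    joint = probs.reshape(2 ** n_a, 2 ** (n - n_a))\n",
        "    p_ancilla0 = joint[0].sum()  # <Psi| Pi x I |Psi> with Pi = |0><0| on the ancillas\n",
        "    g = joint[0] / p_ancilla0  # sub-generator output g^(i); sums to 1\n",
        "    x_tilde = g / g.max()  # post-processed patch; brightest pixel equals 1\n",
        "    return g, x_tilde\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(1)\n",
        "probs = rng.random(2 ** N)\n",
        "probs /= probs.sum()  # hypothetical measurement distribution\n",
        "g, x_tilde = postselect_and_rescale(probs)\n",
        "print(round(g.sum(), 6), x_tilde.max(), x_tilde.shape)\n",
        "```\n",
        "\n",
        "The renormalisation constant cancels in the final division by the maximum.\n",
        "As a result, $\\boldsymbol{\\tilde{x}}^{(i)}$ depends only on the relative\n",
        "probabilities of the post-selected basis states.\n"
      ]
    },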
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "ByOocBfSs2V5"
      },
      "outputs": [],
      "source": [
        "# Quantum variables\n",
        "n_qubits = 5  # Total number of qubits / N\n",
        "n_a_qubits = 1  # Number of ancillary qubits / N_A\n",
        "q_depth = 10  # Depth of the parameterised quantum circuit / D\n",
        "n_generators = 4  # Number of subgenerators for the patch method / N_G"
      ]
    },
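    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With these settings the pieces fit together exactly. Each sub-generator yields\n",
        "$2^{N-N_A} = 16$ pixels, and four of them give the 64 pixels of an 8x8 image.\n",
        "Each sub-generator also has $D \\times N = 50$ trainable angles, one RY rotation\n",
        "per qubit per layer. A short check of this bookkeeping:\n",
        "\n",
        "```python\n",
        "n_qubits = 5\n",
        "n_a_qubits = 1\n",
        "q_depth = 10\n",
        "n_generators = 4\n",
        "image_size = 8\n",
        "\n",
        "patch_size = 2 ** (n_qubits - n_a_qubits)  # pixels produced by one sub-generator\n",
        "params_per_generator = q_depth * n_qubits  # one RY angle per qubit per layer\n",
        "\n",
        "# The patches must tile the whole image exactly\n",
        "assert n_generators * patch_size == image_size * image_size\n",
        "print(patch_size, params_per_generator, n_generators * params_per_generator)  # 16 50 200\n",
        "```\n",
        "\n",
        "If `n_qubits` or `n_generators` is changed, the assertion flags configurations\n",
        "whose patches no longer cover the image.\n"
      ]
    },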
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ow6pzPrJs2V5"
      },
      "source": [
        "Now we define the quantum device we want to use, along with a CUDA\n",
        "GPU for PyTorch if one is available.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "hve9vkDEs2V5"
      },
      "outputs": [],
      "source": [
        "# Quantum simulator\n",
        "dev = qml.device(\"lightning.qubit\", wires=n_qubits)\n",
        "# Enable CUDA device if available\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ish10FIUs2V5"
      },
      "source": [
        "Next, we define the quantum circuit and measurement process described\n",
        "above.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "LNy_ET6ns2V5"
      },
      "outputs": [],
      "source": [
        "@qml.qnode(dev, interface=\"torch\", diff_method=\"parameter-shift\")\n",
        "def quantum_circuit(noise, weights):\n",
        "\n",
        "    weights = weights.reshape(q_depth, n_qubits)\n",
        "\n",
        "    # Initialise latent vectors\n",
        "    for i in range(n_qubits):\n",
        "        qml.RY(noise[i], wires=i)\n",
        "\n",
        "    # Repeated layer\n",
        "    for i in range(q_depth):\n",
        "        # Parameterised layer\n",
        "        for y in range(n_qubits):\n",
        "            qml.RY(weights[i][y], wires=y)\n",
        "\n",
        "        # Control Z gates\n",
        "        for y in range(n_qubits - 1):\n",
        "            qml.CZ(wires=[y, y + 1])\n",
        "\n",
        "    return qml.probs(wires=list(range(n_qubits)))\n",
        "\n",
        "\n",
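        "# For further info on how the non-linear transform is implemented in PennyLane\n",
        "# https://discuss.pennylane.ai/t/ancillary-subsystem-measurement-then-trace-out/1532\n",
        "def partial_measure(noise, weights):\n",
        "    # Non-linear transform: post-select on the ancillas (the first N_A wires) being |0>.\n",
        "    # qml.probs orders basis states with wire 0 as the most significant bit,\n",
        "    # so the first 2**(N - N_A) entries are those with all ancillas in |0>.\n",
        "    probs = quantum_circuit(noise, weights)\n",
        "    probsgiven0 = probs[: (2 ** (n_qubits - n_a_qubits))]\n",
        "    # Divide by the post-selection probability <Psi| Pi x I |Psi>\n",
        "    probsgiven0 = probsgiven0 / torch.sum(probsgiven0)\n",
        "\n",
        "    # Post-processing: rescale so the brightest pixel in the patch equals 1\n",
        "    probsgiven = probsgiven0 / torch.max(probsgiven0)\n",
        "    return probsgiven"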
      ]
    },
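    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `quantum_circuit` computes without PennyLane, here is a minimal\n",
        "NumPy statevector sketch of the same gate sequence. The noise is embedded with\n",
        "RY gates, followed by `q_depth` layers of RY rotations, each followed by a\n",
        "chain of CZ gates. The sketch assumes PennyLane's big-endian convention, in\n",
        "which wire 0 is the most significant bit of the basis-state index. This is why\n",
        "the first $2^{N-N_A}$ probabilities correspond to the ancilla on wire 0 being\n",
        "$|0\\rangle$. The helper names are our own, and the code is meant for checking\n",
        "small cases, not for replacing the simulator.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "N_QUBITS, Q_DEPTH = 5, 10\n",
        "\n",
        "\n",
        "def apply_ry(state, theta, wire, n):\n",
        "    # RY(theta) on one wire; wire 0 is the most significant bit (big-endian)\n",
        "    c, s = np.cos(theta / 2), np.sin(theta / 2)\n",
        "    ry = np.array([[c, -s], [s, c]])\n",
        "    psi = np.moveaxis(state.reshape([2] * n), wire, 0)\n",
        "    psi = np.tensordot(ry, psi, axes=([1], [0]))\n",
        "    return np.moveaxis(psi, 0, wire).reshape(-1)\n",
        "\n",
        "\n",
        "def apply_cz(state, w1, w2, n):\n",
        "    # CZ flips the sign of every amplitude where both wires are |1>\n",
        "    psi = state.reshape([2] * n).copy()\n",
        "    idx = [slice(None)] * n\n",
        "    idx[w1], idx[w2] = 1, 1\n",
        "    psi[tuple(idx)] *= -1\n",
        "    return psi.reshape(-1)\n",
        "\n",
        "\n",
        "def circuit_probs(noise, weights, n=N_QUBITS, depth=Q_DEPTH):\n",
        "    weights = np.asarray(weights).reshape(depth, n)\n",
        "    state = np.zeros(2 ** n)\n",
        "    state[0] = 1.0  # start in |00...0>\n",
        "    for q in range(n):  # latent vector embedding\n",
        "        state = apply_ry(state, noise[q], q, n)\n",
        "    for layer in range(depth):  # repeated parameterised layers\n",
        "        for q in range(n):\n",
        "            state = apply_ry(state, weights[layer, q], q, n)\n",
        "        for q in range(n - 1):  # nearest-neighbour CZ chain\n",
        "            state = apply_cz(state, q, q + 1, n)\n",
        "    return state ** 2  # amplitudes stay real for RY and CZ gates\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "noise = rng.random(N_QUBITS) * np.pi / 2\n",
        "weights = rng.random(Q_DEPTH * N_QUBITS)\n",
        "probs = circuit_probs(noise, weights)\n",
        "print(probs.sum(), probs[: 2 ** (N_QUBITS - 1)].sum())  # total, then P(ancilla = 0)\n",
        "```\n",
        "\n",
        "If PyTorch and PennyLane are installed, we would expect\n",
        "`circuit_probs(noise, weights)` to agree with\n",
        "`quantum_circuit(torch.tensor(noise), torch.tensor(weights))` up to\n",
        "floating-point error. This makes it a convenient cross-check when the\n",
        "circuit is modified.\n"
      ]
    },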
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OG5gYWSvs2V5"
      },
      "source": [
        "Now we create a quantum generator class to use during training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "n6DbP7d7s2V5"
      },
      "outputs": [],
      "source": [
        "class PatchQuantumGenerator(nn.Module):\n",
        "    \"\"\"Quantum generator class for the patch method\"\"\"\n",
        "\n",
        "    def __init__(self, n_generators, q_delta=1):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            n_generators (int): Number of sub-generators to be used in the patch method.\n",
        "            q_delta (float, optional): Spread of the random distribution for parameter initialisation.\n",
        "        \"\"\"\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        self.q_params = nn.ParameterList(\n",
        "            [\n",
        "                nn.Parameter(q_delta * torch.rand(q_depth * n_qubits), requires_grad=True)\n",
        "                for _ in range(n_generators)\n",
        "            ]\n",
        "        )\n",
        "        self.n_generators = n_generators\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Size of each sub-generator output\n",
        "        patch_size = 2 ** (n_qubits - n_a_qubits)\n",
        "\n",
        "        # Create a Tensor to 'catch' a batch of images from the for loop. x.size(0) is the batch size.\n",
        "        images = torch.Tensor(x.size(0), 0).to(device)\n",
        "\n",
        "        # Iterate over all sub-generators\n",
        "        for params in self.q_params:\n",
        "\n",
        "            # Create a Tensor to 'catch' a batch of the patches from a single sub-generator\n",
        "            patches = torch.Tensor(0, patch_size).to(device)\n",
        "            for elem in x:\n",
        "                q_out = partial_measure(elem, params).float().unsqueeze(0)\n",
        "                patches = torch.cat((patches, q_out))\n",
        "\n",
        "            # Each batch of patches is concatenated with each other to create a batch of images\n",
        "            images = torch.cat((images, patches), 1)\n",
        "\n",
        "        return images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hdzKXAD5s2V5"
      },
      "source": [
        "Training\n",
        "========\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BJbPX_TSs2V5"
      },
      "source": [
        "Let\\'s define the learning rates and the number of iterations for the\n",
        "training process.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "ROVGLoQ0s2V5"
      },
      "outputs": [],
      "source": [
        "lrG = 0.3  # Learning rate for the generator\n",
        "lrD = 0.01  # Learning rate for the discriminator\n",
        "num_iter = 500  # Number of training iterations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T-tO_03Ys2V5"
      },
      "source": [
        "Now we put everything together and run the training process.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "XaowdvYgs2V5",
        "outputId": "9d40d53c-6cd8-4d8d-92e9-ceedfb78dcae"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Iteration: 10, Discriminator Loss: 1.374, Generator Loss: 0.582\n",
            "Iteration: 20, Discriminator Loss: 1.355, Generator Loss: 0.595\n",
            "Iteration: 30, Discriminator Loss: 1.365, Generator Loss: 0.580\n",
            "Iteration: 40, Discriminator Loss: 1.374, Generator Loss: 0.576\n",
            "Iteration: 50, Discriminator Loss: 1.352, Generator Loss: 0.589\n",
            "Iteration: 60, Discriminator Loss: 1.333, Generator Loss: 0.605\n",
            "Iteration: 70, Discriminator Loss: 1.348, Generator Loss: 0.600\n",
            "Iteration: 80, Discriminator Loss: 1.314, Generator Loss: 0.599\n",
            "Iteration: 90, Discriminator Loss: 1.333, Generator Loss: 0.601\n",
            "Iteration: 100, Discriminator Loss: 1.312, Generator Loss: 0.606\n",
            "Iteration: 110, Discriminator Loss: 1.279, Generator Loss: 0.629\n",
            "Iteration: 120, Discriminator Loss: 1.309, Generator Loss: 0.604\n",
            "Iteration: 130, Discriminator Loss: 1.253, Generator Loss: 0.624\n",
            "Iteration: 140, Discriminator Loss: 1.315, Generator Loss: 0.604\n",
            "Iteration: 150, Discriminator Loss: 1.227, Generator Loss: 0.660\n",
            "Iteration: 160, Discriminator Loss: 1.268, Generator Loss: 0.610\n",
            "Iteration: 170, Discriminator Loss: 1.306, Generator Loss: 0.599\n",
            "Iteration: 180, Discriminator Loss: 1.296, Generator Loss: 0.615\n",
            "Iteration: 190, Discriminator Loss: 1.264, Generator Loss: 0.604\n",
            "Iteration: 200, Discriminator Loss: 1.321, Generator Loss: 0.589\n",
            "Iteration: 210, Discriminator Loss: 1.208, Generator Loss: 0.667\n",
            "Iteration: 220, Discriminator Loss: 1.290, Generator Loss: 0.622\n",
            "Iteration: 230, Discriminator Loss: 1.236, Generator Loss: 0.628\n",
            "Iteration: 240, Discriminator Loss: 1.282, Generator Loss: 0.630\n",
            "Iteration: 250, Discriminator Loss: 1.190, Generator Loss: 0.660\n",
            "Iteration: 260, Discriminator Loss: 1.343, Generator Loss: 0.585\n",
            "Iteration: 270, Discriminator Loss: 1.303, Generator Loss: 0.624\n",
            "Iteration: 280, Discriminator Loss: 1.244, Generator Loss: 0.688\n",
            "Iteration: 290, Discriminator Loss: 1.269, Generator Loss: 0.641\n",
            "Iteration: 300, Discriminator Loss: 1.273, Generator Loss: 0.616\n",
            "Iteration: 310, Discriminator Loss: 1.247, Generator Loss: 0.759\n",
            "Iteration: 320, Discriminator Loss: 1.323, Generator Loss: 0.612\n",
            "Iteration: 330, Discriminator Loss: 1.281, Generator Loss: 0.647\n",
            "Iteration: 340, Discriminator Loss: 1.212, Generator Loss: 0.760\n",
            "Iteration: 350, Discriminator Loss: 1.322, Generator Loss: 0.636\n",
            "Iteration: 360, Discriminator Loss: 1.334, Generator Loss: 0.631\n",
            "Iteration: 370, Discriminator Loss: 1.283, Generator Loss: 0.721\n",
            "Iteration: 380, Discriminator Loss: 1.209, Generator Loss: 0.685\n",
            "Iteration: 390, Discriminator Loss: 1.332, Generator Loss: 0.629\n",
            "Iteration: 400, Discriminator Loss: 1.161, Generator Loss: 0.695\n",
            "Iteration: 410, Discriminator Loss: 1.340, Generator Loss: 0.630\n",
            "Iteration: 420, Discriminator Loss: 1.126, Generator Loss: 0.703\n",
            "Iteration: 430, Discriminator Loss: 1.204, Generator Loss: 0.753\n",
            "Iteration: 440, Discriminator Loss: 1.124, Generator Loss: 0.735\n",
            "Iteration: 450, Discriminator Loss: 1.136, Generator Loss: 0.743\n",
            "Iteration: 460, Discriminator Loss: 1.169, Generator Loss: 0.731\n",
            "Iteration: 470, Discriminator Loss: 1.054, Generator Loss: 0.785\n",
            "Iteration: 480, Discriminator Loss: 1.204, Generator Loss: 0.729\n",
            "Iteration: 490, Discriminator Loss: 1.075, Generator Loss: 0.778\n",
            "Iteration: 500, Discriminator Loss: 1.227, Generator Loss: 0.644\n"
          ]
        }
      ],
      "source": [
        "discriminator = Discriminator().to(device)\n",
        "generator = PatchQuantumGenerator(n_generators).to(device)\n",
        "\n",
        "# Binary cross entropy\n",
        "criterion = nn.BCELoss()\n",
        "\n",
        "# Optimisers\n",
        "optD = optim.SGD(discriminator.parameters(), lr=lrD)\n",
        "optG = optim.SGD(generator.parameters(), lr=lrG)\n",
        "\n",
        "real_labels = torch.full((batch_size,), 1.0, dtype=torch.float, device=device)\n",
        "fake_labels = torch.full((batch_size,), 0.0, dtype=torch.float, device=device)\n",
        "\n",
        "# Fixed noise allows us to visually track the generated images throughout training\n",
        "fixed_noise = torch.rand(8, n_qubits, device=device) * math.pi / 2\n",
        "\n",
        "# Iteration counter\n",
        "counter = 0\n",
        "\n",
        "# Collect images for plotting later\n",
        "results = []\n",
        "\n",
        "while True:\n",
        "    for i, (data, _) in enumerate(dataloader):\n",
        "\n",
        "        # Data for training the discriminator\n",
        "        data = data.reshape(-1, image_size * image_size)\n",
        "        real_data = data.to(device)\n",
        "\n",
        "        # Noise following a uniform distribution in range [0, pi/2)\n",
        "        noise = torch.rand(batch_size, n_qubits, device=device) * math.pi / 2\n",
        "        fake_data = generator(noise)\n",
        "\n",
        "        # Training the discriminator\n",
        "        discriminator.zero_grad()\n",
        "        outD_real = discriminator(real_data).view(-1)\n",
        "        outD_fake = discriminator(fake_data.detach()).view(-1)\n",
        "\n",
        "        errD_real = criterion(outD_real, real_labels)\n",
        "        errD_fake = criterion(outD_fake, fake_labels)\n",
        "        # Propagate gradients\n",
        "        errD_real.backward()\n",
        "        errD_fake.backward()\n",
        "\n",
        "        errD = errD_real + errD_fake\n",
        "        optD.step()\n",
        "\n",
        "        # Training the generator\n",
        "        generator.zero_grad()\n",
        "        outD_fake = discriminator(fake_data).view(-1)\n",
        "        errG = criterion(outD_fake, real_labels)\n",
        "        errG.backward()\n",
        "        optG.step()\n",
        "\n",
        "        counter += 1\n",
        "\n",
        "        # Show loss values\n",
        "        if counter % 10 == 0:\n",
        "            print(f'Iteration: {counter}, Discriminator Loss: {errD:0.3f}, Generator Loss: {errG:0.3f}')\n",
        "            test_images = generator(fixed_noise).view(8,1,image_size,image_size).cpu().detach()\n",
        "\n",
        "            # Save images every 50 iterations\n",
        "            if counter % 50 == 0:\n",
        "                results.append(test_images)\n",
        "\n",
        "        if counter == num_iter:\n",
        "            break\n",
        "    if counter == num_iter:\n",
        "        break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uDK5G66Bs2V5"
      },
      "source": [
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "You may have noticed `errG = criterion(outD_fake, real_labels)` and\n",
        "wondered why we don't use `fake_labels` instead of `real_labels`.\n",
        "This is simply a trick that lets us use the same `criterion`\n",
        "function for both the generator and the discriminator. Using `real_labels`\n",
        "forces the generator loss function to use the $\\log(D(G(z)))$ term\n",
        "instead of the $\\log(1 - D(G(z)))$ term of the binary cross entropy loss\n",
        "function.\n",
        ":::\n"
      ]
    },
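    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The trick in the note above can be checked numerically. The following is a minimal sketch, assuming only PyTorch: the discriminator outputs `d_of_g_z` are made-up values standing in for $D(G(z))$, and `ones`/`zeros` play the roles of `real_labels`/`fake_labels`. It compares `nn.BCELoss` against the two log terms computed by hand.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "criterion = nn.BCELoss()\n",
        "\n",
        "# Hypothetical discriminator outputs D(G(z)) for a batch of three fake images\n",
        "d_of_g_z = torch.tensor([0.2, 0.5, 0.9])\n",
        "ones = torch.ones_like(d_of_g_z)    # stands in for real_labels\n",
        "zeros = torch.zeros_like(d_of_g_z)  # stands in for fake_labels\n",
        "\n",
        "# Target 1: BCE reduces to the batch mean of -log(D(G(z)))\n",
        "loss_target_one = criterion(d_of_g_z, ones)\n",
        "expected_one = -torch.log(d_of_g_z).mean()\n",
        "\n",
        "# Target 0: BCE reduces to the batch mean of -log(1 - D(G(z)))\n",
        "loss_target_zero = criterion(d_of_g_z, zeros)\n",
        "expected_zero = -torch.log(1 - d_of_g_z).mean()\n",
        "\n",
        "print(f'target 1: {loss_target_one.item():.4f} vs {expected_one.item():.4f}')\n",
        "print(f'target 0: {loss_target_zero.item():.4f} vs {expected_zero.item():.4f}')\n",
        "```\n",
        "\n",
        "Minimising the target-1 loss drives $D(G(z))$ towards 1. Because the gradient of $-\\log D(G(z))$ grows as $D(G(z))$ approaches 0, the generator still receives a strong signal early in training, when the discriminator easily rejects its samples; this form is usually called the non-saturating generator loss.\n"
      ]
    },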
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xsa1GwZus2V6"
      },
      "source": [
        "Finally, we plot how the generated images evolved throughout training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 417
        },
        "id": "fvpigTx8s2V6",
        "outputId": "80be262d-635f-4269-f558-af22312ee315"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x500 with 80 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "fig = plt.figure(figsize=(10, 5))\n",
        "outer = gridspec.GridSpec(5, 2, wspace=0.1)\n",
        "\n",
        "for i, images in enumerate(results):\n",
        "    inner = gridspec.GridSpecFromSubplotSpec(1, images.size(0),\n",
        "                    subplot_spec=outer[i])\n",
        "\n",
        "    images = torch.squeeze(images, dim=1)\n",
        "    for j, im in enumerate(images):\n",
        "\n",
        "        ax = plt.Subplot(fig, inner[j])\n",
        "        ax.imshow(im.numpy(), cmap=\"gray\")\n",
        "        ax.set_xticks([])\n",
        "        ax.set_yticks([])\n",
        "        if j==0:\n",
        "            ax.set_title(f'Iteration {50+i*50}', loc='left')\n",
        "        fig.add_subplot(ax)\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "seconds = time.time()\n",
        "print(\"Time in seconds since end of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "y4dVhN44beWS",
        "outputId": "e034b58b-afd9-4403-889c-47ad30b8a9be"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since end of run: 1693266593.7901826\n",
            "Mon Aug 28 23:49:53 2023\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THvPX4IBs2V6"
      },
      "source": [
        "Acknowledgements\n",
        "================\n",
        "\n",
        "Many thanks to Karolis Špukas, with whom I co-developed much of the code.\n",
        "I also extend my thanks to Dr. Yuxuan Du for answering my questions\n",
        "regarding his paper. I am also indebted to the PennyLane community for\n",
        "their help over the past few years.\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.17"
    },
    "colab": {
      "provenance": [],
      "gpuType": "V100"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}