{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 104,
      "metadata": {
        "id": "IZi5TUAp-1Lt"
      },
      "outputs": [],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "%matplotlib inline\n",
        "# !pip install pennylane"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yx4UlxKp-1Lu"
      },
      "source": [
        "Training and evaluating quantum kernels\n",
        "=======================================\n",
        "\n",
        "::: {.meta}\n",
        ":property=\\\"og:description\\\": Kernels and alignment training with\n",
        "Pennylane. :property=\\\"og:image\\\":\n",
        "<https://pennylane.ai/qml/_images/QEK_thumbnail.png>\n",
        ":::\n",
        "\n",
        "::: {.related}\n",
        "tutorial\\_kernel\\_based\\_training Kernel-based training with\n",
        "scikit-learn tutorial\\_data\\_reuploading\\_classifier Data-reuploading\n",
        "classifier\n",
        ":::\n",
        "\n",
        "*Authors: Peter-Jan Derks, Paul K. Faehrmann, Elies Gil-Fuster, Tom\n",
        "Hubregtsen, Johannes Jakob Meyer and David Wierichs --- Posted: 24 June\n",
        "2021. Last updated: 18 November 2021.*\n",
        "\n",
        "Kernel methods are one of the cornerstones of classical machine\n",
        "learning. Here we are concerned with kernels that can be evaluated on\n",
        "quantum computers, *quantum kernels* for short. In this tutorial you\n",
        "will learn how to evaluate kernels, use them for classification and\n",
        "train them with gradient-based optimization, and all that using the\n",
        "functionality of PennyLane\\'s [kernels\n",
        "module](https://pennylane.readthedocs.io/en/latest/code/qml_kernels.html).\n",
        "The demo is based on Ref., a project from Xanadu\\'s own\n",
        "[QHack](https://qhack.ai/) hackathon.\n",
        "\n",
        "What are kernel methods?\n",
        "------------------------\n",
        "\n",
        "To understand what a kernel method does, let\\'s first revisit one of the\n",
        "simplest methods to assign binary labels to datapoints: linear\n",
        "classification.\n",
        "\n",
        "Imagine we want to discern two different classes of points that lie in\n",
        "different corners of the plane. A linear classifier corresponds to\n",
        "drawing a line and assigning different labels to the regions on opposing\n",
        "sides of the line:\n",
        "\n",
        "![](../demonstrations/kernels_module/linear_classification.png){.align-center\n",
        "width=\"30.0%\"}\n",
        "\n",
        "We can mathematically formalize this by assigning the label $y$ via\n",
        "\n",
        "$$y(\\boldsymbol{x}) = \\operatorname{sgn}(\\langle \\boldsymbol{w}, \\boldsymbol{x}\\rangle + b).$$\n",
        "\n",
        "The vector $\\boldsymbol{w}$ points perpendicular to the line and thus\n",
        "determine its slope. The independent term $b$ specifies the position on\n",
        "the plane. In this form, linear classification can also be extended to\n",
        "higher dimensional vectors $\\boldsymbol{x}$, where a line does not\n",
        "divide the entire space into two regions anymore. Instead one needs a\n",
        "*hyperplane*. It is immediately clear that this method is not very\n",
        "powerful, as datasets that are not separable by a hyperplane can\\'t be\n",
        "classified without error.\n",
        "\n",
        "We can actually sneak around this limitation by performing a neat trick:\n",
        "if we define some map $\\phi(\\boldsymbol{x})$ that *embeds* our\n",
        "datapoints into a larger *feature space* and then perform linear\n",
        "classification there, we could actually realise non-linear\n",
        "classification in our original space!\n",
        "\n",
        "![](../demonstrations/kernels_module/embedding_nonlinear_classification.png){.align-center\n",
        "width=\"65.0%\"}\n",
        "\n",
        "If we go back to the expression for our prediction and include the\n",
        "embedding, we get\n",
        "\n",
        "$$y(\\boldsymbol{x}) = \\operatorname{sgn}(\\langle \\boldsymbol{w}, \\phi(\\boldsymbol{x})\\rangle + b).$$\n",
        "\n",
        "We will forgo one tiny step, but it can be shown that for the purpose of\n",
        "optimal classification, we can choose the vector defining the decision\n",
        "boundary as a linear combination of the embedded datapoints\n",
        "$\\boldsymbol{w} = \\sum_i \\alpha_i \\phi(\\boldsymbol{x}_i)$. Putting this\n",
        "into the formula yields\n",
        "\n",
        "$$y(\\boldsymbol{x}) = \\operatorname{sgn}\\left(\\sum_i \\alpha_i \\langle \\phi(\\boldsymbol{x}_i), \\phi(\\boldsymbol{x})\\rangle + b\\right).$$\n",
        "\n",
        "This rewriting might not seem useful at first, but notice the above\n",
        "formula only contains inner products between vectors in the embedding\n",
        "space:\n",
        "\n",
        "$$k(\\boldsymbol{x}_i, \\boldsymbol{x}_j) = \\langle \\phi(\\boldsymbol{x}_i), \\phi(\\boldsymbol{x}_j)\\rangle.$$\n",
        "\n",
        "We call this function the *kernel*. It provides the advantage that we\n",
        "can often find an explicit formula for the kernel $k$ that makes it\n",
        "superfluous to actually perform the (potentially expensive) embedding\n",
        "$\\phi$. Consider for example the following embedding and the associated\n",
        "kernel:\n",
        "\n",
        "$$\\begin{aligned}\n",
        "\\phi((x_1, x_2)) &= (x_1^2, \\sqrt{2} x_1 x_2, x_2^2) \\\\\n",
        "k(\\boldsymbol{x}, \\boldsymbol{y}) &= x_1^2 y_1^2 + 2 x_1 x_2 y_1 y_2 + x_2^2 y_2^2 = \\langle \\boldsymbol{x}, \\boldsymbol{y} \\rangle^2.\n",
        "\\end{aligned}$$\n",
        "\n",
        "This means by just replacing the regular scalar product in our linear\n",
        "classification with the map $k$, we can actually express much more\n",
        "intricate decision boundaries!\n",
        "\n",
        "This is very important, because in many interesting cases the embedding\n",
        "$\\phi$ will be much costlier to compute than the kernel $k$.\n",
        "\n",
        "In this demo, we will explore one particular kind of kernel that can be\n",
        "realized on near-term quantum computers, namely *Quantum Embedding\n",
        "Kernels (QEKs)*. These are kernels that arise from embedding data into\n",
        "the space of quantum states. We formalize this by considering a\n",
        "parameterised quantum circuit $U(\\boldsymbol{x})$ that maps a datapoint\n",
        "$\\boldsymbol{x}$ to the state\n",
        "\n",
        "$$|\\psi(\\boldsymbol{x})\\rangle = U(\\boldsymbol{x}) |0 \\rangle.$$\n",
        "\n",
        "The kernel value is then given by the *overlap* of the associated\n",
        "embedded quantum states\n",
        "\n",
        "$$k(\\boldsymbol{x}_i, \\boldsymbol{x}_j) = | \\langle\\psi(\\boldsymbol{x}_i)|\\psi(\\boldsymbol{x}_j)\\rangle|^2.$$\n"
      ]
    },
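    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the overlap formula, the sketch below evaluates\n",
        "it by hand for a hypothetical single-qubit embedding: a Hadamard gate\n",
        "followed by a rotation $R_Z(x)$ applied to $|0\\rangle$. This toy embedding\n",
        "and the helper names `embed` and `overlap_kernel` are assumptions for\n",
        "illustration only and use nothing but the Python standard library.\n",
        "\n",
        "```python\n",
        "import cmath\n",
        "import math\n",
        "\n",
        "\n",
        "def embed(x):\n",
        "    '''Amplitudes of RZ(x) H |0>, with RZ(x) = diag(exp(-ix/2), exp(ix/2)).'''\n",
        "    norm = 1 / math.sqrt(2)\n",
        "    return (norm * cmath.exp(-0.5j * x), norm * cmath.exp(0.5j * x))\n",
        "\n",
        "\n",
        "def overlap_kernel(x, y):\n",
        "    '''Squared magnitude of the inner product of the two embedded states.'''\n",
        "    psi_x, psi_y = embed(x), embed(y)\n",
        "    inner = sum(a.conjugate() * b for a, b in zip(psi_x, psi_y))\n",
        "    return abs(inner) ** 2\n",
        "\n",
        "\n",
        "k_same = overlap_kernel(0.4, 0.4)   # identical inputs overlap completely\n",
        "k_diff = overlap_kernel(0.4, 1.9)   # equals cos((0.4 - 1.9) / 2) ** 2\n",
        "```\n",
        "\n",
        "For this toy embedding the kernel depends only on the difference of the two\n",
        "inputs and is symmetric in its arguments.\n"
      ]
    },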
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YFfhnwX7-1Lv"
      },
      "source": [
        "A toy problem\n",
        "=============\n",
        "\n",
        "In this demo, we will treat a toy problem that showcases the inner\n",
        "workings of classification with quantum embedding kernels, training\n",
        "variational embedding kernels and the available functionalities to do\n",
        "both in PennyLane. We of course need to start with some imports:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 105,
      "metadata": {
        "id": "_0dLIZDi-1Lv"
      },
      "outputs": [],
      "source": [
        "from pennylane import numpy as np\n",
        "import matplotlib as mpl\n",
        "\n",
        "np.random.seed(1359)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rxY0RHIl-1Lv"
      },
      "source": [
        "And we proceed right away to create a dataset to work with, the\n",
        "`DoubleCake` dataset. Firstly, we define two functions to enable us to\n",
        "generate the data. The details of these functions are not essential for\n",
        "understanding the demo, so don\\'t mind them if they are confusing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 106,
      "metadata": {
        "id": "-MlPehLl-1Lw"
      },
      "outputs": [],
      "source": [
        "def _make_circular_data(num_sectors):\n",
        "    \"\"\"Generate datapoints arranged in an even circle.\"\"\"\n",
        "    center_indices = np.array(range(0, num_sectors))\n",
        "    sector_angle = 2 * np.pi / num_sectors\n",
        "    angles = (center_indices + 0.5) * sector_angle\n",
        "    x = 0.7 * np.cos(angles)\n",
        "    y = 0.7 * np.sin(angles)\n",
        "    labels = 2 * np.remainder(np.floor_divide(angles, sector_angle), 2) - 1\n",
        "\n",
        "    return x, y, labels\n",
        "\n",
        "\n",
        "def make_double_cake_data(num_sectors):\n",
        "    x1, y1, labels1 = _make_circular_data(num_sectors)\n",
        "    x2, y2, labels2 = _make_circular_data(num_sectors)\n",
        "\n",
        "    # x and y coordinates of the datapoints\n",
        "    x = np.hstack([x1, 0.5 * x2])\n",
        "    y = np.hstack([y1, 0.5 * y2])\n",
        "\n",
        "    # Canonical form of dataset\n",
        "    X = np.vstack([x, y]).T\n",
        "\n",
        "    labels = np.hstack([labels1, -1 * labels2])\n",
        "\n",
        "    # Canonical form of labels\n",
        "    Y = labels.astype(int)\n",
        "\n",
        "    return X, Y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sE-g9OvE-1Lw"
      },
      "source": [
        "Next, we define a function to help plot the `DoubleCake` data:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 107,
      "metadata": {
        "id": "dTzNJQG2-1Lw"
      },
      "outputs": [],
      "source": [
        "def plot_double_cake_data(X, Y, ax, num_sectors=None):\n",
        "    \"\"\"Plot double cake data and corresponding sectors.\"\"\"\n",
        "    x, y = X.T\n",
        "    cmap = mpl.colors.ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
        "    ax.scatter(x, y, c=Y, cmap=cmap, s=25, marker=\"s\")\n",
        "\n",
        "    if num_sectors is not None:\n",
        "        sector_angle = 360 / num_sectors\n",
        "        for i in range(num_sectors):\n",
        "            color = [\"#FF0000\", \"#0000FF\"][(i % 2)]\n",
        "            other_color = [\"#FF0000\", \"#0000FF\"][((i + 1) % 2)]\n",
        "            ax.add_artist(\n",
        "                mpl.patches.Wedge(\n",
        "                    (0, 0),\n",
        "                    1,\n",
        "                    i * sector_angle,\n",
        "                    (i + 1) * sector_angle,\n",
        "                    lw=0,\n",
        "                    color=color,\n",
        "                    alpha=0.1,\n",
        "                    width=0.5,\n",
        "                )\n",
        "            )\n",
        "            ax.add_artist(\n",
        "                mpl.patches.Wedge(\n",
        "                    (0, 0),\n",
        "                    0.5,\n",
        "                    i * sector_angle,\n",
        "                    (i + 1) * sector_angle,\n",
        "                    lw=0,\n",
        "                    color=other_color,\n",
        "                    alpha=0.1,\n",
        "                )\n",
        "            )\n",
        "            ax.set_xlim(-1, 1)\n",
        "\n",
        "    ax.set_ylim(-1, 1)\n",
        "    ax.set_aspect(\"equal\")\n",
        "    ax.axis(\"off\")\n",
        "\n",
        "    return ax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5W9DtU_s-1Lw"
      },
      "source": [
        "Let\\'s now have a look at our dataset. In our example, we will work with\n",
        "3 sectors:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 108,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 406
        },
        "id": "9fVgxo-2-1Lw",
        "outputId": "44fcfa20-c1d4-4bf9-97ad-65fca917e2c6"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "num_sectors = 3\n",
        "X, Y = make_double_cake_data(num_sectors)\n",
        "\n",
        "ax = plot_double_cake_data(X, Y, plt.gca(), num_sectors=num_sectors)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YiOE-Kte-1Lw"
      },
      "source": [
        "Defining a Quantum Embedding Kernel\n",
        "===================================\n",
        "\n",
        "PennyLane\\'s [kernels\n",
        "module](https://pennylane.readthedocs.io/en/latest/code/qml_kernels.html)\n",
        "allows for a particularly simple implementation of Quantum Embedding\n",
        "Kernels. The first ingredient we need for this is an *ansatz*, which we\n",
        "will construct by repeating a layer as building block. Let\\'s start by\n",
        "defining this layer:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 109,
      "metadata": {
        "id": "HoxftykS-1Lw"
      },
      "outputs": [],
      "source": [
        "import pennylane as qml\n",
        "\n",
        "\n",
        "def layer(x, params, wires, i0=0, inc=1):\n",
        "    \"\"\"Building block of the embedding ansatz\"\"\"\n",
        "    i = i0\n",
        "    for j, wire in enumerate(wires):\n",
        "        qml.Hadamard(wires=[wire])\n",
        "        qml.RZ(x[i % len(x)], wires=[wire])\n",
        "        i += inc\n",
        "        qml.RY(params[0, j], wires=[wire])\n",
        "\n",
        "    qml.broadcast(unitary=qml.CNOT, pattern=\"ring\", wires=wires) #, parameters=params[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vz-Xc85B-1Lx"
      },
      "source": [
        "To construct the ansatz, this layer is repeated multiple times, reusing\n",
        "the datapoint `x` but feeding different variational parameters `params`\n",
        "into each of them. Together, the datapoint and the variational\n",
        "parameters fully determine the embedding ansatz $U(\\boldsymbol{x})$. In\n",
        "order to construct the full kernel circuit, we also require its adjoint\n",
        "$U(\\boldsymbol{x})^\\dagger$, which we can obtain via `qml.adjoint`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 110,
      "metadata": {
        "id": "0u3Ie4pX-1Lx"
      },
      "outputs": [],
      "source": [
        "def ansatz(x, params, wires):\n",
        "    \"\"\"The embedding ansatz\"\"\"\n",
        "    for j, layer_params in enumerate(params):\n",
        "        layer(x, layer_params, wires, i0=j * len(wires))\n",
        "\n",
        "\n",
        "adjoint_ansatz = qml.adjoint(ansatz)\n",
        "\n",
        "\n",
        "def random_params(num_wires, num_layers):\n",
        "    \"\"\"Generate random variational parameters in the shape for the ansatz.\"\"\"\n",
        "    return np.random.uniform(0, 2 * np.pi, (num_layers, 2, num_wires), requires_grad=True)"
      ]
    },
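    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the datapoint has fewer features than the circuit has wires, the\n",
        "index `i % len(x)` in `layer` cycles through the features, and the offset\n",
        "`i0=j * len(wires)` lets each layer continue where the previous one\n",
        "stopped. The plain-Python sketch below reproduces only this index\n",
        "bookkeeping; the helper `feature_indices` is hypothetical and not part of\n",
        "the demo or of PennyLane.\n",
        "\n",
        "```python\n",
        "def feature_indices(num_features, num_wires, num_layers):\n",
        "    '''Index of x fed to the RZ gate on each wire, layer by layer.'''\n",
        "    table = []\n",
        "    for layer_idx in range(num_layers):\n",
        "        i0 = layer_idx * num_wires   # offset passed to layer() as i0\n",
        "        table.append([(i0 + k) % num_features for k in range(num_wires)])\n",
        "    return table\n",
        "\n",
        "\n",
        "# Two features on five wires, as in this demo; only two layers are shown\n",
        "indices = feature_indices(num_features=2, num_wires=5, num_layers=2)\n",
        "```\n",
        "\n",
        "With two features and five wires, the first layer encodes the features in\n",
        "the order `0, 1, 0, 1, 0` and the second layer in the order `1, 0, 1, 0, 1`.\n"
      ]
    },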
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-3A0BtiZ-1Lx"
      },
      "source": [
        "Together with the ansatz we only need a device to run the quantum\n",
        "circuit on. For the purpose of this tutorial we will use PennyLane\\'s\n",
        "`default.qubit` device with 5 wires in analytic mode.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 111,
      "metadata": {
        "id": "VajRrjnM-1Lx"
      },
      "outputs": [],
      "source": [
        "dev = qml.device(\"default.qubit\", wires=5, shots=None)\n",
        "wires = dev.wires.tolist()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LBQiS_k5-1Lx"
      },
      "source": [
        "Let us now define the quantum circuit that realizes the kernel. We will\n",
        "compute the overlap of the quantum states by first applying the\n",
        "embedding of the first datapoint and then the adjoint of the embedding\n",
        "of the second datapoint. We finally extract the probabilities of\n",
        "observing each basis state.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 112,
      "metadata": {
        "id": "YUoDYzxl-1Lx"
      },
      "outputs": [],
      "source": [
        "@qml.qnode(dev, interface=\"autograd\")\n",
        "def kernel_circuit(x1, x2, params):\n",
        "    ansatz(x1, params, wires=wires)\n",
        "    adjoint_ansatz(x2, params, wires=wires)\n",
        "    return qml.probs(wires=wires)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vp3JMl6k-1Lx"
      },
      "source": [
        "The kernel function itself is now obtained by looking at the probability\n",
        "of observing the all-zero state at the end of the kernel circuit --\n",
        "because of the ordering in `qml.probs`, this is the first entry:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 113,
      "metadata": {
        "id": "FzkdR5rH-1Lx"
      },
      "outputs": [],
      "source": [
        "def kernel(x1, x2, params):\n",
        "    return kernel_circuit(x1, x2, params)[0]"
      ]
    },
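    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The all-zero probability equals the kernel because\n",
        "$|\\langle 0 | U(\\boldsymbol{x}_2)^\\dagger U(\\boldsymbol{x}_1) | 0 \\rangle|^2 = |\\langle\\psi(\\boldsymbol{x}_2)|\\psi(\\boldsymbol{x}_1)\\rangle|^2$.\n",
        "To see why this is entry `0`, the plain-Python sketch below lists the\n",
        "computational basis states in lexicographic order, which is the ordering\n",
        "assumed here for the probability vector. The helper `basis_labels` is\n",
        "hypothetical and not a PennyLane function.\n",
        "\n",
        "```python\n",
        "def basis_labels(num_wires):\n",
        "    '''Bit strings of all basis states in lexicographic order.'''\n",
        "    width = '0{}b'.format(num_wires)\n",
        "    return [format(i, width) for i in range(2 ** num_wires)]\n",
        "\n",
        "\n",
        "labels = basis_labels(5)                  # five wires, as in this demo\n",
        "all_zero_index = labels.index('00000')    # position of the all-zero state\n",
        "```\n",
        "\n",
        "The list has one entry per basis state, and the all-zero state comes first.\n"
      ]
    },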
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v8d7TzEV-1Lx"
      },
      "source": [
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "An alternative way to set up the kernel circuit in PennyLane would be to\n",
        "use the observable type\n",
        "[Projector](https://pennylane.readthedocs.io/en/latest/code/api/pennylane.Projector.html).\n",
        "This is shown in the [demo on kernel-based training of quantum\n",
        "models](https://pennylane.ai/qml/demos/tutorial_kernel_based_training.html),\n",
        "where you will also find more background information on the kernel\n",
        "circuit structure itself.\n",
        ":::\n",
        "\n",
        "Before focusing on the kernel values we have to provide values for the\n",
        "variational parameters. At this point we fix the number of layers in the\n",
        "ansatz circuit to $6$.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 114,
      "metadata": {
        "id": "XND4t8Re-1Lx"
      },
      "outputs": [],
      "source": [
        "init_params = random_params(num_wires=5, num_layers=6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D13009dA-1Lx"
      },
      "source": [
        "Now we can have a look at the kernel value between the first and the\n",
        "second datapoint:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 115,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "EPIxD8tx-1Lx",
        "outputId": "4b351f5a-f601-4ea0-ed23-8279cae3e773"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The kernel value between the first and second datapoint is 0.006\n"
          ]
        }
      ],
      "source": [
        "kernel_value = kernel(X[0], X[1], init_params)\n",
        "print(f\"The kernel value between the first and second datapoint is {kernel_value:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zftXzrUX-1Lx"
      },
      "source": [
        "The mutual kernel values between all elements of the dataset form the\n",
        "*kernel matrix*. We can inspect it via the\n",
        "`qml.kernels.square_kernel_matrix` method, which makes use of symmetry\n",
        "of the kernel,\n",
        "$k(\\boldsymbol{x}_i,\\boldsymbol{x}_j) = k(\\boldsymbol{x}_j, \\boldsymbol{x}_i)$.\n",
        "In addition, the option `assume_normalized_kernel=True` ensures that we\n",
        "do not calculate the entries between the same datapoints, as we know\n",
        "them to be 1 for our noiseless simulation. Overall this means that we\n",
        "compute $\\frac{1}{2}(N^2-N)$ kernel values for $N$ datapoints. To\n",
        "include the variational parameters, we construct a `lambda` function\n",
        "that fixes them to the values we sampled above.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 116,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "EJ6j5XF1-1Lx",
        "outputId": "f1d08b57-896d-461e-d9bd-a5a1d371352e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[1.    0.006 0.035 0.567 0.04  0.029]\n",
            " [0.006 1.    0.063 0.067 0.574 0.101]\n",
            " [0.035 0.063 1.    0.048 0.07  0.603]\n",
            " [0.567 0.067 0.048 1.    0.21  0.204]\n",
            " [0.04  0.574 0.07  0.21  1.    0.285]\n",
            " [0.029 0.101 0.603 0.204 0.285 1.   ]]\n"
          ]
        }
      ],
      "source": [
        "init_kernel = lambda x1, x2: kernel(x1, x2, init_params)\n",
        "K_init = qml.kernels.square_kernel_matrix(X, init_kernel, assume_normalized_kernel=True)\n",
        "\n",
        "with np.printoptions(precision=3, suppress=True):\n",
        "    print(K_init)"
      ]
    },
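    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The counting argument above can be sketched in plain NumPy. This is an\n",
        "illustration only, not PennyLane's implementation: the helper\n",
        "`symmetric_kernel_matrix` and the classical `gaussian_kernel` stand-in\n",
        "are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def symmetric_kernel_matrix(X, kernel, assume_normalized_kernel=False):\n",
        "    # Hypothetical helper: fill only the upper triangle and mirror it.\n",
        "    N = len(X)\n",
        "    K = np.eye(N)\n",
        "    evaluations = 0\n",
        "    for i in range(N):\n",
        "        # Skip the diagonal when it is known to be 1.\n",
        "        start = i + 1 if assume_normalized_kernel else i\n",
        "        for j in range(start, N):\n",
        "            K[i, j] = K[j, i] = kernel(X[i], X[j])\n",
        "            evaluations += 1\n",
        "    return K, evaluations\n",
        "\n",
        "def gaussian_kernel(x1, x2):\n",
        "    # Classical stand-in for the quantum kernel; equals 1 for x1 == x2.\n",
        "    return np.exp(-np.sum((x1 - x2) ** 2))\n",
        "\n",
        "points = np.array([[0.0, 0.1], [0.5, -0.2], [-0.3, 0.8], [0.9, 0.4]])\n",
        "K, evaluations = symmetric_kernel_matrix(\n",
        "    points, gaussian_kernel, assume_normalized_kernel=True\n",
        ")\n",
        "print(evaluations)  # 6, i.e. N * (N - 1) / 2 for N = 4\n",
        "```\n",
        "\n",
        "Without the normalization assumption the same helper also evaluates the\n",
        "diagonal, giving N * (N + 1) / 2 kernel calls instead.\n"
      ]
    },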
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nRfsymWl-1Lx"
      },
      "source": [
        "Using the Quantum Embedding Kernel for predictions\n",
        "==================================================\n",
        "\n",
        "The quantum kernel alone cannot be used to make predictions on a\n",
        "dataset, because it is essentially just a tool to measure the similarity\n",
        "between two datapoints. To perform an actual prediction we will make use\n",
        "of scikit-learn\\'s Support Vector Classifier (SVC).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 117,
      "metadata": {
        "id": "k6YZKy7J-1Lx"
      },
      "outputs": [],
      "source": [
        "from sklearn.svm import SVC"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s3GDnc3R-1Ly"
      },
      "source": [
        "To construct the SVM, we need to supply `sklearn.svm.SVC` with a\n",
        "function that takes two sets of datapoints and returns the associated\n",
        "kernel matrix. We can make use of the function\n",
        "`qml.kernels.kernel_matrix` that provides this functionality. It expects\n",
        "the kernel to not have additional parameters besides the datapoints,\n",
        "which is why we again supply the variational parameters via the `lambda`\n",
        "function from above. Once we have this, we can let scikit-learn fit\n",
        "the SVM using our Quantum Embedding Kernel.\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "This step does *not* modify the variational parameters in our circuit\n",
        "ansatz. Instead, it solves a different optimization task for the\n",
        "$\\alpha$ and $b$ vectors we introduced in the beginning.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 118,
      "metadata": {
        "id": "irVOPzix-1Ly"
      },
      "outputs": [],
      "source": [
        "svm = SVC(kernel=lambda X1, X2: qml.kernels.kernel_matrix(X1, X2, init_kernel)).fit(X, Y)"
      ]
    },
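    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The callable-kernel interface of `SVC` can be tried out without any\n",
        "quantum circuit. The sketch below is a stand-alone illustration under\n",
        "stated assumptions: `gaussian_kernel` is a classical stand-in for the\n",
        "quantum kernel, `gram_matrix` is a hypothetical helper, and the toy\n",
        "data is made up for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from sklearn.svm import SVC\n",
        "\n",
        "def gaussian_kernel(x1, x2):\n",
        "    # Classical stand-in for the quantum kernel.\n",
        "    return np.exp(-np.sum((x1 - x2) ** 2))\n",
        "\n",
        "def gram_matrix(A, B):\n",
        "    # Entry (i, j) is the kernel between A[i] and B[j].\n",
        "    # SVC expects an array of shape (len(A), len(B)).\n",
        "    return np.array([[gaussian_kernel(a, b) for b in B] for a in A])\n",
        "\n",
        "# Two well-separated clusters with labels -1 and +1 (made-up data).\n",
        "X_toy = np.array([[-1.0, -1.0], [-0.8, -0.6], [0.9, 1.0], [0.7, 0.8]])\n",
        "y_toy = np.array([-1, -1, 1, 1])\n",
        "\n",
        "clf = SVC(kernel=gram_matrix).fit(X_toy, y_toy)\n",
        "predictions = clf.predict(X_toy)\n",
        "print(predictions)\n",
        "```\n",
        "\n",
        "During `fit` the callable receives the training set twice; during\n",
        "`predict` it receives the new points and the training set, so the\n",
        "returned matrix is in general rectangular.\n"
      ]
    },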
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I4xR2Txu-1Ly"
      },
      "source": [
        "To see how well our classifier performs we will measure which percentage\n",
        "of the dataset it classifies correctly.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 119,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "ZDX6Qvwk-1Ly",
        "outputId": "560bf8f1-7cb8-41e8-c20b-51797641b4bc"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The accuracy of the kernel with random parameters is 1.000\n"
          ]
        }
      ],
      "source": [
        "def accuracy(classifier, X, Y_target):\n",
        "    return 1 - np.count_nonzero(classifier.predict(X) - Y_target) / len(Y_target)\n",
        "\n",
        "\n",
        "accuracy_init = accuracy(svm, X, Y)\n",
        "print(f\"The accuracy of the kernel with random parameters is {accuracy_init:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PPB_3tDn-1Ly"
      },
      "source": [
        "We are also interested in seeing what the decision boundaries in this\n",
        "classification look like. This could help us spot overfitting issues\n",
        "visually in more complex data sets. To this end we will introduce a\n",
        "second helper method.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 120,
      "metadata": {
        "id": "EQZNfZjQ-1Ly"
      },
      "outputs": [],
      "source": [
        "def plot_decision_boundaries(classifier, ax, N_gridpoints=14):\n",
        "    _xx, _yy = np.meshgrid(np.linspace(-1, 1, N_gridpoints), np.linspace(-1, 1, N_gridpoints))\n",
        "\n",
        "    _zz = np.zeros_like(_xx)\n",
        "    for idx in np.ndindex(*_xx.shape):\n",
        "        # predict returns a length-1 array; take its single element explicitly\n",
        "        _zz[idx] = classifier.predict(np.array([_xx[idx], _yy[idx]])[np.newaxis, :])[0]\n",
        "\n",
        "    plot_data = {\"_xx\": _xx, \"_yy\": _yy, \"_zz\": _zz}\n",
        "    ax.contourf(\n",
        "        _xx,\n",
        "        _yy,\n",
        "        _zz,\n",
        "        cmap=mpl.colors.ListedColormap([\"#FF0000\", \"#0000FF\"]),\n",
        "        alpha=0.2,\n",
        "        levels=[-1, 0, 1],\n",
        "    )\n",
        "    plot_double_cake_data(X, Y, ax)\n",
        "\n",
        "    return plot_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26F1Bqg9-1Ly"
      },
      "source": [
        "With that done, let\\'s have a look at the decision boundaries for our\n",
        "initial classifier:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 121,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 406
        },
        "id": "JQ33viIF-1Ly",
        "outputId": "7bf10573-ddec-4d53-dc6a-3ad1457ddb91"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAK4klEQVR4nO3dPVLjzBqA0Z5bRFqEc8feCQtlJ46dexGk/gLQexlGBsmW1H/npHQVjaXuB7XxzJ/b7XZLAJBS+l/uCQBQDlEAIIgCAEEUAAiiAEAQBQCCKAAQRAGA8DJ75Pm84TQA2Nzp9OsQTwoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAgvuScAZHS9pvT+Pv21YUjpcNh3PmQnCtCr6zWl19efx7y9CUNnZkfhfBm2nEcxTsc7vzVRhfE+dR1nuPeEsHQMTZn/pHA8bjiNQlwu6XwZbCiVOl+Gj/vUdYSHeaP5q8/w9fJU1JIIQkquIzxBFL6zoVTnryCMXEd4iChMsaFUYzIII9cRFhOFe2woxfsxCCPXERYRhZ/YUIo1Kwgj13HaMOP1mDOGpvy53W63OQPP562nUrDLJaXkzxxLsSgIX7mO//Lhtb6cTr8OEYW5bChFeDgII9eRns2IguOjuRxBZPd0EFJyHeEXorCEDSWbVYIwch3hLlFYyoayu1WDMHIdYZIoPMKGsptNgjByHeEfovAoG8rmNg3CyHWEv4jCM2wom9klCCPXEYIoPKuHfz22B8IAKSVRgP8TBhAF+Isw0DlRgO+EgY6JAkwRBjolCnCPMNAhUYCfCAOdEQX4jTDQEVGAOYSBTogCzCUMdEAUYAlhoHGiAEsJAw0TBXiEMNCol9wTgGodjyldLlnC4P+YZiuiAM/I8a/kfoZIGNiC4yOojaMrNiQKUCNhYCOiALUSBjYgClAzYWBlogC1EwZWJArQAmFgJaIArRAGViAK0BJh4EmiAK0RBp4gCtAiYeBBogCtEgYeIArQMmFgIVGA1gkDC4gC9EAYmEkU1nA8WmyUTxiYQRTWIgzUQBj4hSisSRiogTDwA1FYmzBQA2HgDlHYgjBQA2FggihsRRiogTDwjShsSRiogTDwhShsTRiogTDwSRT2IAzUQBhIorAfYaAGwtC9l9wT6MrxmM6XSzod33PPhEJcrym937kdhiGlw2Hf+aSUPsJwuaTzZXCvdkgU9iYMfLpeU3p9/XnM25swsC/HRzk4SiLdf0JYOmYzjpK65EkhF08M1ODLE8PerI08RCEnYaAGn08Me7M28nB8lJujJJhmbWQhCiVw88M0a2N3olAKNz9MszZ2JQolcfN3ZZhxqeeM6YK1sZs/t9vtNmfg+bz1VAjeYPvYADK9wbmnIj+8VjJr4zmn069D/PVRifxVUjds+gtZG5tzfFQqj8swzdrYlCiUzM0P06yNzYhC6dz8MM3a2IQo1MDND9OsjdWJQi3c/DDN2liVKNTEzQ/TrI3ViEJt3PwwzdpYhSgA7RCGp4kC0BZheIooAO0RhoeJAtAmYXiIKADtEobFRAFomzAsIgpA+4RhNlEA+iAMs4gC0A9h+JUoAH0Rhh+JAtAfYbhLFIA+CcMkUQD6JQz/eMk9AYCsjsd0vlxyz2IXp9PvY0QB4HjMPYNiOD4CIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAE
AQBQCCKAAQRAGAIAoABFEAIIgCAEEUAAiiAEAQBQCCKAAQRAGA8JJ7Arlcrym9v09/bRhSOhz2nU8WXgTgmy6jcL2m9Pr685i3t8b3RC8CMKHL46N7vxwvHVM1LwIwocsoADBNFAAIogBAEAUAgigAELqMwjCsM6ZqXgRgQpefUzgcPv4Ev+vPbXkRgAldRiEl+11KyYsA/KPL4yMApokCAEEUAAiiAEAQBQCCKAAQRAGAIAqVOl/a/bRxyz8blE4UanQ8ppTa3DzjZ/r8GYF9iUKtGgyDIEB+olCzhsIgCFAGUahdA2EQBCiHKLSg4jAIApRFFFpRYRgEAcojCi2pKAyCAGUShdZUEAZBgHKJQosKDoMgQNlEoVUFhkEQoHyi0LKCwiAIUAdRaF0BYRAEqIco9CBjGAQB6vKSewLs5HhM6XLJ88QgCFANUeiJzRn4heMjAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIA4SX3BKBn12tK7+/TXxuGlA6HfecDogCZXK8pvb7+PObtTRjYl+MjyOTeE8LSMbAmUQAgiAIAwXsKQFkul3zf+3jM970LIQpAOT6DcDru/2bK+TJ8fP/Ow+D4CChDxiD89X1zPqkUQBQgk2FYZ0wTMgdhJAwp/bndbrc5A8/nracC/fHhtVRMEL46Xz5r3NhR0un0+xhRAPIpMAijFsMwJwqOj4A8Cg5CSv0eJYkCsL/CgzDqMQyiAOyrkiCMeguDKAD7qSwIo57CIArAPioNwqiXMIgCsL3KgzDqIQyiAGyrkSCMWg+DKADbaSwIo5bDIArANhoNwqjVMIgCsL7GgzBqMQyiAKyrkyCMWguDKADr6SwIo5bCIArAOjoNwqiVMIgC8LzOgzBqIQyiADxHEP5SexhEAXicIEyqOQyiADxGEH5UaxhEAVhOEGapMQyiACwjCIvUFgZRAOYThIfUFAZRAOYRhKfUEgZRAH4nCKuoIQwvuSfAjgq+ESmfIKzjdHxP58uQZz2ejr8OEYVe+E0PilHyOnR81ANBAGYShdYJArCAKLRMEICFRKFVggA8QBRaJAjAg0ShNYIAPEEUWiIIwJNEoRWCAKxAFFogCMBKRKF2ggCsSBRqJgjAykShVoIAbEAUaiQIwEZEoTaCAGxIFGoiCMDGRKEWggDsQBRqIAjATkShdIIA7EgUSiYIwM5EoVSCAGQgCiUSBCCTl9wT4BtBYE/Xa0rvd+61YUjpcNh3PmQnCiURBPZ0vab0+vrzmLc3YeiM46NSCAJ7u/eEsHQMTRGFEggCUAhRyE0QgIKU+57C52bZA0EASlFmFPz2DJBFecdHggCQTVlREATYzzCsM4amlHN8JAiwr8Ph43MIPrzGF2VEQRAgD5s+3+Q/PhIEgGLkjYIgABQlXxQEAaA4eaIgCABF2j8KggBQrH2jIAgARdsvCoIAULx9oiAIAFXYPgqCAFCNbaMgCABV2S4KggBQnW2iIAgAVVo/CoIAUK11oyAIAFVbLwqCAFC9daIgCABNeD4KggDQjOeiIAgATXk8CoIA0JzHoiAIAE1aHgVBAGjWsigIAkDT5kdBEACat+hJQRAA2jY7CoIA0L59/49mAIomCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYAgCgAEUQAgiAIAQRQACKIAQBAFAIIoABBEAYDw53a73XJPAoAyeFIAIIgCAEEUAAiiAEAQBQCCKAAQRAGAIAoABFEAIPwHavhre0oF7xUAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "init_plot_data = plot_decision_boundaries(svm, plt.gca())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLeU4kMX-1Ly"
      },
      "source": [
        "We see the outer points in the dataset can be correctly classified, but\n",
        "we still struggle with the inner circle. But remember we have a circuit\n",
        "with many free parameters! It is reasonable to believe we can give\n",
        "values to those variational parameters which improve the overall\n",
        "accuracy of our SVC.\n",
        "\n",
        "Training the Quantum Embedding Kernel\n",
        "=====================================\n",
        "\n",
        "To be able to train the Quantum Embedding Kernel we need some measure of\n",
        "how well it fits the dataset in question. Performing an exhaustive\n",
        "search in parameter space is not a good solution because it is very\n",
        "resource intensive, and since the accuracy is a discrete quantity we\n",
        "would not be able to detect small improvements.\n",
        "\n",
        "We can, however, resort to a more specialized measure, the\n",
        "*kernel-target alignment*. The kernel-target alignment compares the\n",
        "similarity predicted by the quantum kernel to the actual labels of the\n",
        "training data. It is based on *kernel alignment*, a similarity measure\n",
        "between two kernels with given kernel matrices $K_1$ and $K_2$:\n",
        "\n",
        "$$\\operatorname{KA}(K_1, K_2) = \\frac{\\operatorname{Tr}(K_1 K_2)}{\\sqrt{\\operatorname{Tr}(K_1^2)\\operatorname{Tr}(K_2^2)}}.$$\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "Seen from a more theoretical side, $\\operatorname{KA}$ is nothing else\n",
        "than the cosine of the angle between the kernel matrices $K_1$ and $K_2$\n",
        "if we see them as vectors in the space of matrices with the\n",
        "Hilbert-Schmidt (or Frobenius) scalar product\n",
        "$\\langle A, B \\rangle = \\operatorname{Tr}(A^T B)$. This reinforces the\n",
        "geometric picture of how this measure relates to objects, namely two\n",
        "kernels, being aligned in a vector space.\n",
        ":::\n",
        "\n",
        "The training data enters the picture by defining an *ideal* kernel\n",
        "function that expresses the original labelling in the vector\n",
        "$\\boldsymbol{y}$ by assigning to two datapoints the product of the\n",
        "corresponding labels:\n",
        "\n",
        "$$k_{\\boldsymbol{y}}(\\boldsymbol{x}_i, \\boldsymbol{x}_j) = y_i y_j.$$\n",
        "\n",
        "The assigned kernel is thus $+1$ if both datapoints lie in the same\n",
        "class and $-1$ otherwise and its kernel matrix is simply given by the\n",
        "outer product $\\boldsymbol{y}\\boldsymbol{y}^T$. The kernel-target\n",
        "alignment is then defined as the kernel alignment of the kernel matrix\n",
        "$K$ generated by the quantum kernel and\n",
        "$\\boldsymbol{y}\\boldsymbol{y}^T$:\n",
        "\n",
        "$$\\operatorname{KTA}_{\\boldsymbol{y}}(K)\n",
        "= \\frac{\\operatorname{Tr}(K \\boldsymbol{y}\\boldsymbol{y}^T)}{\\sqrt{\\operatorname{Tr}(K^2)\\operatorname{Tr}((\\boldsymbol{y}\\boldsymbol{y}^T)^2)}}\n",
        "= \\frac{\\boldsymbol{y}^T K \\boldsymbol{y}}{\\sqrt{\\operatorname{Tr}(K^2)} N}$$\n",
        "\n",
        "where $N$ is the number of elements in $\\boldsymbol{y}$, that is the\n",
        "number of datapoints in the dataset.\n",
        "\n",
        "In summary, the kernel-target alignment effectively captures how well\n",
        "the kernel you chose reproduces the actual similarities of the data. It\n",
        "does have one drawback, however: having a high kernel-target alignment\n",
        "is only a necessary but not a sufficient condition for a good\n",
        "performance of the kernel. This means that a well-performing kernel is\n",
        "guaranteed to have good alignment, but optimal alignment will not always\n",
        "bring optimal training accuracy with it.\n",
        "\n",
        "Let\\'s now come back to the actual implementation. PennyLane\\'s\n",
        "`kernels` module allows you to easily evaluate the kernel target\n",
        "alignment:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 122,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "N1sKPuO3-1Ly",
        "outputId": "349e0b66-48df-40c7-ad7b-cbc1ad93aaee"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The kernel-target alignment for our dataset and random parameters is 0.130\n"
          ]
        }
      ],
      "source": [
        "kta_init = qml.kernels.target_alignment(X, Y, init_kernel, assume_normalized_kernel=True)\n",
        "\n",
        "print(f\"The kernel-target alignment for our dataset and random parameters is {kta_init:.3f}\")"
      ]
    },
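    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To connect the formulas above to numbers, the following NumPy sketch\n",
        "evaluates the kernel alignment directly from kernel matrices. It is an\n",
        "illustration only: `kernel_alignment` and `kernel_target_alignment` are\n",
        "hypothetical helper names, the labels are made up, and no class-label\n",
        "rescaling is applied.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def kernel_alignment(K1, K2):\n",
        "    # Frobenius inner product of K1 and K2, divided by both Frobenius norms.\n",
        "    return np.sum(K1 * K2) / np.sqrt(np.sum(K1 * K1) * np.sum(K2 * K2))\n",
        "\n",
        "def kernel_target_alignment(K, y):\n",
        "    # Alignment between K and the ideal kernel matrix given by outer(y, y).\n",
        "    return kernel_alignment(K, np.outer(y, y))\n",
        "\n",
        "y = np.array([1, 1, -1, -1])\n",
        "ideal = np.outer(y, y)\n",
        "identity = np.eye(4)\n",
        "\n",
        "print(kernel_target_alignment(ideal, y))     # 1.0, the ideal kernel is perfectly aligned\n",
        "print(kernel_target_alignment(identity, y))  # 0.5, i.e. 4 / (sqrt(4) * 4)\n",
        "```\n",
        "\n",
        "The identity matrix corresponds to a kernel that regards all distinct\n",
        "datapoints as completely dissimilar, which is why its alignment with the\n",
        "labels is well below 1.\n"
      ]
    },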
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eOZMWm-p-1Ly"
      },
      "source": [
        "Now let\'s code up an optimization loop and improve the kernel-target\n",
        "alignment!\n",
        "\n",
        "We will make use of regular gradient descent optimization. To speed up\n",
        "the optimization we will not use the entire training set to compute\n",
        "$\\operatorname{KTA}$ but rather sample smaller subsets of the data at\n",
        "each step; here we choose $4$ datapoints at random. Remember that\n",
        "PennyLane\'s built-in optimizer works to *minimize* the cost function\n",
        "that is given to it, which is why we have to multiply the kernel target\n",
        "alignment by $-1$ to actually *maximize* it in the process.\n",
        "\n",
        "::: {.note}\n",
        "::: {.title}\n",
        "Note\n",
        ":::\n",
        "\n",
        "At the time of writing, the function `qml.kernels.target_alignment` is\n",
        "not differentiable, making it unfit for gradient descent optimization.\n",
        "We therefore first define a differentiable version of this function.\n",
        ":::\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 123,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "o8BNL8bP-1Ly",
        "outputId": "43ec31f0-b683-4347-b66b-e7b451bd7b6f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Step 50 - Alignment = 0.156\n",
            "Step 100 - Alignment = 0.177\n",
            "Step 150 - Alignment = 0.192\n",
            "Step 200 - Alignment = 0.227\n",
            "Step 250 - Alignment = 0.263\n",
            "Step 300 - Alignment = 0.300\n",
            "Step 350 - Alignment = 0.312\n",
            "Step 400 - Alignment = 0.316\n",
            "Step 450 - Alignment = 0.320\n",
            "Step 500 - Alignment = 0.324\n"
          ]
        }
      ],
      "source": [
        "def target_alignment(\n",
        "    X,\n",
        "    Y,\n",
        "    kernel,\n",
        "    assume_normalized_kernel=False,\n",
        "    rescale_class_labels=True,\n",
        "):\n",
        "    \"\"\"Kernel-target alignment between kernel and labels.\"\"\"\n",
        "\n",
        "    K = qml.kernels.square_kernel_matrix(\n",
        "        X,\n",
        "        kernel,\n",
        "        assume_normalized_kernel=assume_normalized_kernel,\n",
        "    )\n",
        "\n",
        "    if rescale_class_labels:\n",
        "        nplus = np.count_nonzero(np.array(Y) == 1)\n",
        "        nminus = len(Y) - nplus\n",
        "        _Y = np.array([y / nplus if y == 1 else y / nminus for y in Y])\n",
        "    else:\n",
        "        _Y = np.array(Y)\n",
        "\n",
        "    T = np.outer(_Y, _Y)\n",
        "    inner_product = np.sum(K * T)\n",
        "    norm = np.sqrt(np.sum(K * K) * np.sum(T * T))\n",
        "    inner_product = inner_product / norm\n",
        "\n",
        "    return inner_product\n",
        "\n",
        "\n",
        "params = init_params\n",
        "opt = qml.GradientDescentOptimizer(0.2)\n",
        "\n",
        "for i in range(500):\n",
        "    # Choose subset of datapoints to compute the KTA on.\n",
        "    subset = np.random.choice(list(range(len(X))), 4)\n",
        "    # Define the cost function for optimization\n",
        "    cost = lambda _params: -target_alignment(\n",
        "        X[subset],\n",
        "        Y[subset],\n",
        "        lambda x1, x2: kernel(x1, x2, _params),\n",
        "        assume_normalized_kernel=True,\n",
        "    )\n",
        "    # Optimization step\n",
        "    params = opt.step(cost, params)\n",
        "\n",
        "    # Report the alignment on the full dataset every 50 steps.\n",
        "    if (i + 1) % 50 == 0:\n",
        "        current_alignment = target_alignment(\n",
        "            X,\n",
        "            Y,\n",
        "            lambda x1, x2: kernel(x1, x2, params),\n",
        "            assume_normalized_kernel=True,\n",
        "        )\n",
        "        print(f\"Step {i+1} - Alignment = {current_alignment:.3f}\")"
      ]
    },
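    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One detail of the loop above is that `np.random.choice` draws indices\n",
        "with replacement by default, so a subset of 4 may contain the same\n",
        "datapoint more than once. Whether this matters is a judgment call; the\n",
        "sketch below, which assumes plain NumPy and an arbitrarily chosen seed,\n",
        "only shows how distinct indices could be drawn instead.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(seed=0)  # seed chosen arbitrarily\n",
        "num_points, batch_size = 6, 4\n",
        "\n",
        "# Default behaviour: indices are drawn with replacement, repeats can occur.\n",
        "with_replacement = rng.choice(num_points, size=batch_size)\n",
        "\n",
        "# replace=False guarantees distinct indices in every mini-batch.\n",
        "without_replacement = rng.choice(num_points, size=batch_size, replace=False)\n",
        "\n",
        "print(len(set(without_replacement.tolist())))  # 4\n",
        "```\n",
        "\n",
        "Both variants give a stochastic estimate of the alignment; sampling\n",
        "without replacement simply avoids spending kernel evaluations on\n",
        "duplicate pairs.\n"
      ]
    },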
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lhbr4YTr-1Lz"
      },
      "source": [
        "We want to assess the impact of training the parameters of the quantum\n",
        "kernel. Thus, let\\'s build a second support vector classifier with the\n",
        "trained kernel:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 124,
      "metadata": {
        "id": "vUg8wNDJ-1Lz"
      },
      "outputs": [],
      "source": [
        "# First create a kernel with the trained parameter baked into it.\n",
        "trained_kernel = lambda x1, x2: kernel(x1, x2, params)\n",
        "\n",
        "# Second create a kernel matrix function using the trained kernel.\n",
        "trained_kernel_matrix = lambda X1, X2: qml.kernels.kernel_matrix(X1, X2, trained_kernel)\n",
        "\n",
        "# Note that SVC expects the kernel argument to be a kernel matrix function.\n",
        "svm_trained = SVC(kernel=trained_kernel_matrix).fit(X, Y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LgE9Tcmn-1Lz"
      },
      "source": [
        "We expect to see an accuracy improvement vs. the SVM with random\n",
        "parameters:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 125,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "SGUIEF1X-1Lz",
        "outputId": "e189e6b0-5e12-4fb5-9473-3c8667b6bcb2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The accuracy of a kernel with trained parameters is 1.000\n"
          ]
        }
      ],
      "source": [
        "accuracy_trained = accuracy(svm_trained, X, Y)\n",
        "print(f\"The accuracy of a kernel with trained parameters is {accuracy_trained:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N5tAagq1-1Lz"
      },
      "source": [
        "We have now achieved perfect classification! 🎆\n",
        "\n",
        "Since SVMs have been shown to generalise well, it will be interesting\n",
        "to inspect the decision boundaries of our classifier:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 126,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 406
        },
        "id": "CHXRBR7p-1Lz",
        "outputId": "fe39d39a-1cfd-4702-c4b0-e1eb696be350"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAL0ElEQVR4nO3dPVYjuxaA0XPfIvIgnDtmJgyUmRA79yBIeQHNuX27bfBPVUlH2jtFa5WW26qPUlvmn4+Pj48AgIj4X+sJANAPUQAgiQIASRQASKIAQBIFAJIoAJBEAYD0dO3At7c1pwF/OB7j+fDeehYwlufnH4d4UgAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIF+nM8tp4BTEsU6MuvIDwf3htPBOYkCvRDEKA5UaAPggBdEAXaEwTohijQliBAV0SBdgQBuiMKtCEI0CVRYHuCAN0SBbYlCNA1UWA7ggDdEwW2IQhQgiiwPkGAMkSBdQkClPLUegIMbOMgvB13m1znHNGjZ19r4/n557GiwDpaBeFw2OR6f1//KAx06da1YfuI5U0WhK9rt3xSgXPuWRuiwLJmDMIXYaAj964NUWA5MwfhizDQgUfWhiiwDEH4lzDQ0KNrQxR4nCD8TRhoYIm1IQo8RhAuEwY2tNTaEAXuJwg/EwY2sOTaEAXuIwjXEwZWtPTaEAVuJwi3EwZWsMbaEAVuIwj3EwYWtNbaEAWuJwiPEwYWsObaEAWuIwjLEQYesPbaEAV+1urrr0cMwhdh4A5b/LIkCnzP30OALmz19CwKXCYI0IUtt1NFgfMEAbqw9f+viQJ/EwToQosPXIgC/yUI0IVWn8ATBf4lCNCFlh/JFgU+CQJ0ofUZHVFAEKATrYMQIQoIAnShhyBEiMLcBAG60EsQIkRhXoIAXegpCBERT60nQAOCwJfTKeL9wvtgt4vY77edz2R6C0KEKMxHEPhyOkW8vHw/5vVVGFbSYxAibB/NRRD43aUnhFvHcLNegxAhCvMQBOhCz0GIEIU5CAJ0ofcgRIjC+AQBulAhCBGiMDZBgC5UCUKEKIxLEKALlYIQIQpjEgSusbvib0RfM4aLqgUhwjmF8QgC19rvP88hOLy2iopBiBCFsQgCt3LTX0XVIETYPhqHIEAXKgchQhTGIAjQhepBiBCF+gQBujBCECJEoTZBgC6MEoQIUahLEKALIwUhQhRqEgTowmhBiBCFsgRhDHlToa6BghAhCtDOr5uJMNATUYCWhIHOiAK0Jgx0RBSgB8JAJ0QBeiEMdEAUoCfCQGOiAL0RBhoSBeiRMNCIKECvhIEGRAF6JgxsTBSgd8LAhkQBKhAGNiIKUIUwsAFRgEqEgZWJAlQjDKxIFKAiYWAlogBVCQMrEAWoTBhYmChAdcLAgkQBRiAMLEQUYBTCwAJEAUYiDDxIFGA0wsADRAFGJAzcSRRgVMLAHUQBRiYM3EgUYHTCwA1EAWYgDFxJFGAWwsAVRAFmIgz8QBRgNsLAN0QBZiQMXCAKMCth4AxRgJkJA38QBZidMPAbUQCEgSQKwCdhIEQB+J0wTO+p9QRgZqdTxPv7+Z/tdhH7/bbziYjPMByP8XbcxfPhwuQYlihAI6dTxMvL92NeX4WBbdk+gkYuPSHcOmY1tpKmJArAZcIwHVEAvicMUxEF4GfCMA1RAK4jDFMQBeB6wjA8UYBGdlfcV68ZszlhGJpzCtDIfv95DqG7w2vXcI5hWKIADXV707+GMAzJ9hFwP1tJwxEF4DHCMBRRAB4nDMMQBWAZwjAEUQCWIwzliQKwLGEoTRSA5QlDWaIArONXGKhFFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAC
RRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgPTUegKtnE4R7+/nf7bbRez3286nCS8C8Icpo3A6Rby8fD/m9XXwe6IXAThjyu2jS78c3zqmNC8CcMaUUQDgPFEAIIkCAEkUAEiiAECaMgq73TJjSvMiAGdMeU5hv//8CP7U57a8CMAZU0Yhwv0uIrwIwF+m3D4C4DxRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogCs43iM58N761lwI1EAlicIZYkCsCxBKE0UgOUIQnmiACxDEIYgCsDjBGEYogA8RhCGIgrA/QRhOKIA3EcQhvTUegIws9Mp4v3CfXW3i9jvt53P1QRhWKIAjZxOES8v3495fe0wDIIwNNtH0MilJ4Rbx2xKEIYnCsB1BGEKogD8TBCmIQrA9wRhKqIAXCYI0xEF4DxBmJIoQCO73TJjViEI03JOARrZ7z/PIXR3eE0QpiYK0JCDafTG9hHwSRAIUQAiBIEkCjA7QeA3ogAzEwT+IAowK0HgDFGAGQkCF4gCzEYQ+IYowEwEgR+IQlFvx1bff0BZgsAVRKGiwyEihIEbCAJXEoWqhIFrCQI3EIXKhIGfCAI3EoXqhIFLBIE7iMIIhIE/CQJ3EoVRCANfBIEHiMJIhAFB4EGiMBphmJcgsABRGJEwzEcQWIgojEoY5iEILEgURiYM4xMEFiYKoxOGcQkCKxCFGQjDeASBlYjCLIRhHILAikRhJsJQnyCwMlGYjTDUJQhsQBRmJAz1CAIbEYVZCUMdgsCGRGFmwtA/QWBjojA7YeiXINCAKCAMPRIEGhEFPglDPwSBhkSBfwlDe4JAY6LAfwlDO4JAB0SBvwnD9gSBTogC5wnDdgSBjogClwnD+gSBzogC3xOG9QjCGI7H1jNYlCjws43D8Hbc5TWhZxn1gcIgClxnozAIAtWMFgZR4Horh0EQqGqkMIgCt1kpDIJAdaOEQRS43cJhEARGMUIYRIH7LBQGQWA01cMgCtzvwTAIAqOqHAZR4DF3hkEQGF3VMDy1ngADOBwijsd4O+6uOowlCB05nSLeL/yb7XYR+/228xnM8+H98/1+PJZ5z4sCy7gyDILQkdMp4uXl+zGvr8LwoGphsH3Ecn7YShKEzlx6Qrh1DD+qtJUkCizrQhgEgdlVCYMosLw/wiAI8KlCGESBdfweBkGA1HsYRIH1HA6CAGf0HAZRAGig1zCIAsxqd8WBw2vGcLcew+CcAsxqv/88h+DwWlO9nWMQBZiZm34XegqD7SOADvSylSQKAJ3oIQyiANCR1mEQBYDOtAyDKAB0qFUYRAGgUy3CIAoAHds6DKIA0LktwyAKAAVsFQZRgBY6+q4b6rjmb6A/ShRga7+CsMUCh1uJAmxJEOicKMBWBIECRAG2IAgUIQqwNkGgEFGANQkCxYgCrEUQKEgUYA2CQFGiAEsTBAoTBViSIFCcKMBSBIEBiAIsQRAYhCjAowSBgYgCPEIQGIwowL0EgQGJAtxDEBiUKMCtBIGBiQLcQhAYnCjAtQSBCYgCXEMQmIQowE8EgYmIAnxHEJiMKMAlgsCERAHOEQQmJQrwJ0FgYqIAvxMEJicK8EUQQBQgIgQBfhEFEASKeDvuIg6HVa8hCsxNEChiiyBEiAIzEwSK2CoIEaLArASBIrYMQoQoMCNBoIitgxAhCsxGECiiRRAiRIGZCAJFtApChCgwC0GgiJZBiBAFZiAIFNE6CBGiwOgEgSJ6CEKEKDAyQaCIXoIQIQqMShAooqcgRIgCIxIEiugtCBGiwGgEgSJ6DEKEKDASQaCIXo
MQIQqMQhAooucgRIgCIxAEiug9CBGiQHWCQBEVghAhClQmCBRRJQgRokBVgkARlYIQIQpUJAgUUS0IEaJANYJAERWDECEKVCIIFFE1CBGiQBWCQBGVgxAhClQgCBRRPQgRokDvBIEiRghChCjQM0GgiFGCECEK9EoQKGKkIERE/PPx8fHRehIA9MGTAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgDp/7iJjaKJGzOmAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "trained_plot_data = plot_decision_boundaries(svm_trained, plt.gca())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "71P66YVn-1Lz"
      },
      "source": [
        "Indeed, we see that now not only every data instance falls within the\n",
        "correct class, but also that there are no strong artifacts that would\n",
        "make us distrust the model. In this sense, our approach benefits from\n",
        "both: on the one hand it can adjust itself to the dataset, and on the\n",
        "other hand it is not expected to suffer from bad generalisation.\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.17"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}