{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "_VCz6stSPtCC"
      },
      "outputs": [],
      "source": [
        "# This cell is added by sphinx-gallery\n",
        "# It can be customized to whatever you like\n",
        "%matplotlib inline\n",
        "# !pip install pennylane"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "exXC6t6_PtCD"
      },
      "source": [
        "Introduction to Geometric Quantum Machine Learning\n",
        "==================================================\n",
        "\n",
        "::: {.meta}\n",
        ":property=\\\"og:description\\\": Using the natural symmetries in a quantum\n",
        "learning problem can improve learning :property=\\\"og:image\\\":\n",
        "<https://pennylane.ai/qml/_images/equivariant_thumbnail.jpeg>\n",
        ":::\n",
        "\n",
        "::: {.related}\n",
        "tutorial\\_equivariant\\_graph\\_embedding A permutation equivariant graph\n",
        "embedding\n",
        ":::\n",
        "\n",
        "*Author: Richard East --- Posted: 18 October 2022.*\n",
        "\n",
        "Introduction\n",
        "------------\n",
        "\n",
        "Symmetries are at the heart of physics. Indeed in condensed matter and\n",
        "particle physics we often define a thing simply by the symmetries it\n",
        "adheres to. What does symmetry mean for those in machine learning? In\n",
        "this context the ambition is straightforward --- it is a means to reduce\n",
        "the parameter space and improve the trained model\\'s ability to\n",
        "sucessfully label unseen data, i.e., its ability to generalise.\n",
        "\n",
        "Suppose we have a learning task and the data we are learning from has an\n",
        "underlying symmetry. For example, consider a game of Noughts and Crosses\n",
        "(aka Tic-tac-toe): if we win a game, we would have won it if the board\n",
        "was rotated or flipped along any of the lines of symmetry. Now if we\n",
        "want to train an algorithm to spot the outcome of these games, we can\n",
        "either ignore the existence of this symmetry or we can somehow include\n",
        "it. The advantage of paying attention to the symmetry is it identifies\n",
        "multiple configurations of the board as \\'the same thing\\' as far as the\n",
        "symmetry is concerned. This means we can reduce our parameter space, and\n",
        "so the amount of data our algorithm must sift through is immediately\n",
        "reduced. Along the way, the fact that our learning model must encode a\n",
        "symmetry that actually exists in the system we are trying to represent\n",
        "naturally encourages our results to be more generalisable. The encoding\n",
        "of symmetries into our learning models is where the term *equivariance*\n",
        "will appear. We will see that demanding that certain symmetries are\n",
        "included in our models means that the mappings that make up our\n",
        "algorithms must be such that we could transform our input data with\n",
        "respect to a certain symmetry, then apply our mappings, and this would\n",
        "be the same as applying the mappings and then transforming the output\n",
        "data with the same symmetry. This is the technical property that gives\n",
        "us the name \\\"equavariant learning\\\".\n",
        "\n",
        "In classical machine learning, this area is often referred to as\n",
        "geometric deep learning (GDL) due to the traditional association of\n",
        "symmetry to the world of geometry, and the fact that these\n",
        "considerations usually focus on deep neural networks (see or for a broad\n",
        "introduction). We will refer to the quantum computing version of this as\n",
        "*quantum geometric machine learning* (QGML).\n",
        "\n",
        "Representation theory in circuits\n",
        "---------------------------------\n",
        "\n",
        "The first thing to discuss is how do we work with symmetries in the\n",
        "first place? The answer lies in the world of group representation\n",
        "theory.\n",
        "\n",
        "First, let\\'s define what we mean by a group:\n",
        "\n",
        "**Definition**: A group is a set $G$ together with a binary operation on\n",
        "$G$, here denoted $\\circ$, that combines any two elements $a$ and $b$ to\n",
        "form an element of $G$, denoted $a \\circ b$, such that the following\n",
        "three requirements, known as group axioms, are satisfied as follows:\n",
        "\n",
        "1.  **Associativity**: For all $a, b, c$ in $G$, one has\n",
        "    $(a \\circ b) \\circ c=a \\circ (b \\circ c)$.\n",
        "\n",
        "2.  \n",
        "\n",
        "    **Identity element**: There exists an element $e$ in $G$ such that, for every $a$ in $G$, one\n",
        "\n",
        "    :   has $e \\circ a=a$ and $a \\circ e=a$. Such an element is unique.\n",
        "        It is called the identity element of the group.\n",
        "\n",
        "3.  \n",
        "\n",
        "    **Inverse element**: For each $a$ in $G$, there exists an element $b$ in $G$\n",
        "\n",
        "    :   such that $a \\circ b=e$ and $b \\circ a=e$, where $e$ is the\n",
        "        identity element. For each $a$, the element $b$ is unique: it is\n",
        "        called the inverse of $a$ and is commonly denoted $a^{-1}$.\n",
        "\n",
        "With groups defined, we are in a position to articulate what a\n",
        "representation is: Let $\\varphi$ be a map sending $g$ in group $G$ to a\n",
        "linear map $\\varphi(g): V \\rightarrow V$, for some vector space $V$,\n",
        "which satisfies\n",
        "\n",
        "$$\\varphi\\left(g_{1} g_{2}\\right)=\\varphi\\left(g_{1}\\right) \\circ \\varphi\\left(g_{2}\\right) \\quad \\text { for all } g_{1}, g_{2} \\in G.$$\n",
        "\n",
        "The idea here is that just as elements in a group act on each other to\n",
        "reach further elements, i.e., $g\\circ h = k$, a representation sends us\n",
        "to a mapping acting on a vector space such that\n",
        "$\\varphi(g)\\circ \\varphi(h) = \\varphi(k)$. In this way we are\n",
        "representing the structure of the group as a linear map. For a\n",
        "representation, our mapping must send us to the general linear group\n",
        "$GL(n)$ (the space of invertible $n \\times n$ matrices with matrix\n",
        "multiplication as the group multiplication). Note how this is both a\n",
        "group, and by virtue of being a collection of invertible matrices, also\n",
        "a set of linear maps (they\\'re all invertble matrices that can act on\n",
        "row vectors). Fundamentally, representation theory is based on the\n",
        "prosaic observation that linear algebra is easy and group theory is\n",
        "abstract. So what if we can study groups via linear maps?\n",
        "\n",
        "Now due to the importance of unitarity in quantum mechnics, we are\n",
        "particularly interested in the unitary representations: representations\n",
        "where the linear maps are unitary matrices. If we can identify these\n",
        "then we will have a way to naturally encode groups in quantum circuits\n",
        "(which are mostly made up of unitary gates).\n",
        "\n",
        "![](../demonstrations/geometric_qml/sphere_equivariant.png){.align-center\n",
        "width=\"45.0%\"}\n",
        "\n",
        "How does all this relate to symmetries? Well, a large class of\n",
        "symmetries can be characterised as a group, where all the elements of\n",
        "the group leave some space we are considering unchanged. Let\\'s consider\n",
        "an example: the symmetries of a sphere. Now when we think of this\n",
        "symmetry we probably think something along the lines of \\\"it\\'s the same\n",
        "no matter how we rotate it, or flip it left to right, etc\\\". There is\n",
        "this idea of being invariant under some operation. We also have the idea\n",
        "of being able to undo these actions: if we rotate one way, we can rotate\n",
        "it back. If we flip the sphere right-to-left we can flip it\n",
        "left-to-right to get back to where we started (notice too all these\n",
        "inverses are unique). Trivially we can also do nothing. What exactly are\n",
        "we describing here? We have elements that correspond to an action on a\n",
        "sphere that can be inverted and for which there exists an identity. It\n",
        "is also trivially the case here that if we consider three operations a,\n",
        "b, c from the set of rotations and reflections of the sphere, that if we\n",
        "combine two of them together then\n",
        "$a\\circ (b \\circ c) = (a\\circ b) \\circ c$. The operations are\n",
        "associative. These features turn out to literally define a group!\n",
        "\n",
        "As we\\'ve seen the group in itself is a very abstract creature; this is\n",
        "why we look to its representations. The group labels what symmetries we\n",
        "care about, they tell us the mappings that our system is invariant\n",
        "under, and the unitary representations show us how those symmetries look\n",
        "on a particular space of unitary matrices. If we want to encode the\n",
        "structure of the symmeteries in a quantum circuit we must restrict our\n",
        "gates to being unitary representations of the group.\n",
        "\n",
        "There remains one question: *what is equivariance?* With our newfound\n",
        "knowledge of group representation theory we are ready to tackle this.\n",
        "Let $G$ be our group, and $V$ and $W$, with elements $v$ and $w$\n",
        "respectively, be vector spaces over some field $F$ with a map $f$\n",
        "between them. Suppose we have representations\n",
        "$\\varphi: G \\rightarrow GL(V)$ and $\\psi: G \\rightarrow GL(W)$.\n",
        "Furthermore, let\\'s write $\\varphi_g$ for the representation of $g$ as a\n",
        "linear map on $V$ and $\\psi_g$ as the same group element represented as\n",
        "a linear map on $W$ respectively. We call $f$ *equivariant* if\n",
        "\n",
        "$$f(\\varphi_g(v))=\\psi_g(f(v)) \\quad \\text { for all } g\\in G.$$\n",
        "\n",
        "The importance of such a map in machine learning is that if, for\n",
        "example, our neural network layers are equivariant maps then two inputs\n",
        "that are related by some intrinsic symmetry (maybe they are reflections)\n",
        "preserve this information in the outputs.\n",
        "\n",
        "Consider the following figure for example. What we see is a board with a\n",
        "cross in a certain square on the left and some numerical encoding of\n",
        "this on the right, where the 1 is where the X is in the number grid. We\n",
        "present an equivariant mapping between these two spaces with respect to\n",
        "a group action that is a rotation or a swap (here a $\\pi$ rotation). We\n",
        "can either apply a group action to the original grid and then map to the\n",
        "number grid, or we could map to the number grid and then apply the group\n",
        "action. Equivariance demands that the result of either of these\n",
        "procedures should be the same.\n",
        "\n",
        "![](../demonstrations/geometric_qml/equivariant-example.jpg){.align-center\n",
        "width=\"80.0%\"}\n",
        "\n",
        "Given the vast amount of input data required to train a neural network\n",
        "the principle that one can pre-encode known symmetry structures into the\n",
        "network allows us to learn better and faster. Indeed it is the reason\n",
        "for the success of convolutional neural networks (CNNs) for image\n",
        "analysis, where it is known they are equivariant with respect to\n",
        "translations. They naturally encode the idea that a picture of a dog is\n",
        "symmetrically related to the same picture slid to the left by n pixels,\n",
        "and they do this by having neural network layers that are equivariant\n",
        "maps. With our focus on unitary representations (and so quantum\n",
        "circuits) we are looking to extend this idea to quantum machine\n",
        "learning.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ugPilCpKPtCF"
      },
      "source": [
        "Noughts and Crosses\n",
        "===================\n",
        "\n",
        "Let\\'s look at the game of noughts and crosses, as inspired by. Two\n",
        "players take turns to place a O or an X, depending on which player they\n",
        "are, in a 3x3 grid. The aim is to get three of your symbols in a row,\n",
        "column, or diagonal. As this is not always possible depending on the\n",
        "choices of the players, there could be a draw. Our learning task is to\n",
        "take a set of completed games labelled with their outcomes and teach the\n",
        "algorithm to identify these correctly.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "daTwFY-GPtCG"
      },
      "source": [
        "This board of nine elements has the symmetry of the square, also known\n",
        "as the *dihedral group*. This means it is symmetric under\n",
        "$\\frac{\\pi}{2}$ rotations and flips about the lines of symmetry of a\n",
        "square (vertical, horizontal, and both diagonals).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pXOW9XfePtCG"
      },
      "source": [
        "![](../demonstrations/geometric_qml/NandC_sym.png){.align-center\n",
        "width=\"70.0%\"}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "81LgOuSiPtCG"
      },
      "source": [
        "**The question is, how do we encode this in our QML problem?**\n",
        "\n",
        "First, let us encode this problem classically. We will consider a\n",
        "nine-element vector $v$, each element of which identifies a square of\n",
        "the board. The entries themselves can be $+1$,$0$,$-1,$ representing a\n",
        "nought, no symbol, or a cross. The label is one-hot encoded in a vector\n",
        "$y=(y_O,y_- , y_X)$ with $+1$ in the correct label and $-1$ in the\n",
        "others. For instance (-1,-1,1) would represent an X in the relevant\n",
        "position.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EbaXa7F7PtCG"
      },
      "source": [
        "To create the quantum model let us take nine qubits and let them\n",
        "represent squares of our board. We\\'ll initialise them all as\n",
        "$|0\\rangle$, which we note leaves the board invariant under the\n",
        "symmetries of the problem (flip and rotate all you want, it\\'s still\n",
        "going to be zeroes whatever your mapping). We will then look to apply\n",
        "single qubit $R_x(\\theta)$ rotations on individual qubits, encoding each\n",
        "of the possibilities in the board squares at an angle of\n",
        "$\\frac{2\\pi}{3}$ from each other. For our parameterised gates we will\n",
        "have a single-qubit $R_x(\\theta_1)$ and $R_y(\\theta_2)$ rotation at each\n",
        "point. We will then use $CR_y(\\theta_3)$ for two-qubit entangling gates.\n",
        "This implies that, for each encoding, crudely, we\\'ll need 18\n",
        "single-qubit rotation parameters and $\\binom{9}{2}=36$ two-qubit gate\n",
        "rotations. Let\\'s see how, by using symmetries, we can reduce this.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CS1oXl5EPtCG"
      },
      "source": [
        "![..](../demonstrations/geometric_qml/grid.jpg){.align-center\n",
        "width=\"35.0%\"}\n",
        "\n",
        "The indexing of our game board.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3lyfUhwJPtCH"
      },
      "source": [
        "The secret will be to encode the symmetries into the gate set so the\n",
        "observables we are interested in inherently respect the symmetries. How\n",
        "do we do this? We need to select the collections of gates that commute\n",
        "with the symmetries. In general, we can use the twirling formula for\n",
        "this:\n",
        "\n",
        "::: {.tip}\n",
        "::: {.title}\n",
        "Tip\n",
        ":::\n",
        "\n",
        "Let $\\mathcal{S}$ be the group that encodes our symmetries and $U$ be a\n",
        "unitary representation of $\\mathcal{S}$. Then,\n",
        "\n",
        "$$\\mathcal{T}_{U}[X]=\\frac{1}{|\\mathcal{S}|} \\sum_{s \\in \\mathcal{S}} U(s) X U(s)^{\\dagger}$$\n",
        "\n",
        "defines a projector onto the set of operators commuting with all\n",
        "elements of the representation, i.e.,\n",
        "$\\left[\\mathcal{T}_{U}[X], U(s)\\right]=$ 0 for all $X$ and\n",
        "$s \\in \\mathcal{S}$.\n",
        ":::\n",
        "\n",
        "The twirling process applied to an arbitrary unitary will give us a new\n",
        "unitary that commutes with the group as we require. We remember that\n",
        "unitary gates typically have the form $W = \\exp(-i\\theta H)$, where $H$\n",
        "is a Hermitian matrix called a *generator*, and $\\theta$ may be fixed or\n",
        "left as a free parameter. A recipe for creating a unitary that commutes\n",
        "with our symmetries is to *twirl the generator of the gate*, i.e., we\n",
        "move from the gate $W = \\exp(-i\\theta H)$ to the gate\n",
        "$W' = \\exp(-i\\theta\\mathcal{T}_U[H])$. When each term in the twirling\n",
        "formula acts on different qubits, then this unitary would further\n",
        "simplify to\n",
        "\n",
        "$$W' = \\bigotimes_{s\\in\\mathcal{S}}U(s)\\exp(-i\\tfrac{\\theta}{\\vert\\mathcal{S}\\vert})U(s)^\\dagger.$$\n",
        "\n",
        "For simplicity, we can absorb the normalization factor\n",
        "$\\vert\\mathcal{S}\\vert$ into the free parameter $\\theta$.\n",
        "\n",
        "So let\\'s look again at our choice of gates: single-qubit $R_x(\\theta)$\n",
        "and $R_y(\\theta)$ rotations, and entangling two-qubit $CR_y(\\phi)$\n",
        "gates. What will we get by twirling these?\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bbVT5h2HPtCH"
      },
      "source": [
        "In this particular instance we can see the action of the twirling\n",
        "operation geometrically as the symmetries involved are all permutations.\n",
        "Let\\'s consider the $R_x$ rotation acting on one qubit. Now if this\n",
        "qubit is in the centre location on the grid, then we can flip around any\n",
        "symmetry axis we like, and this operation leaves the qubit invariant, so\n",
        "we\\'ve identified one equivariant gate immediately. If the qubit is on\n",
        "the corners, then the flipping will send this qubit rotation to each of\n",
        "the other corners. Similarly, if a qubit is on the central edge then the\n",
        "rotation gate will be sent round the other edges. So we can see that the\n",
        "twirling operation is a sum over all the possible outcomes of performing\n",
        "the symmetry action (the sum over the symmetry group actions). Having\n",
        "done this we can see that for a single-qubit rotation the invariant maps\n",
        "are rotations on the central qubit, at all the corners, and at all the\n",
        "central edges (when their rotation angles are fixed to be the same).\n",
        "\n",
        "As an example consider the following figure, where we take a $R_x$ gate\n",
        "in the corner and then apply all the symmetries of a square. The result\n",
        "of this twirling leads us to have the same gate at all the corners.\n"
      ]
    },
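    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The corner case can be worked through explicitly. As a sketch, assume\n",
        "the representation $U(s)$ simply permutes the qubits, so that\n",
        "$U(s) X_0 U(s)^{\\dagger}=X_{s(0)}$, where $s(0)$ denotes the square\n",
        "that corner $0$ is sent to, and take the Pauli operator $X_0$ as the\n",
        "generator (constant factors are absorbed into $\\theta$). The symmetry\n",
        "group of the square has eight elements, and each corner is reached\n",
        "from corner $0$ by exactly two of them, so\n",
        "\n",
        "$$\\mathcal{T}_{U}[X_0]=\\frac{1}{8} \\sum_{s \\in \\mathcal{S}} X_{s(0)}=\\frac{1}{4}\\left[X_{0}+X_{2}+X_{6}+X_{8}\\right].$$\n",
        "\n",
        "The four terms act on different qubits and therefore commute, which\n",
        "gives\n",
        "\n",
        "$$\\exp\\left(-i\\theta\\,\\mathcal{T}_{U}[X_0]\\right)=\\prod_{i \\in \\text{corners}} \\exp\\left(-i\\tfrac{\\theta}{4} X_{i}\\right),$$\n",
        "\n",
        "i.e., the same $X$ rotation on every corner, controlled by a single\n",
        "shared parameter.\n"
      ]
    },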
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qj-0MV6ePtCH"
      },
      "source": [
        "![](../demonstrations/geometric_qml/twirl.jpeg){.align-center\n",
        "width=\"70.0%\"}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YkTducrePtCH"
      },
      "source": [
        "For entangling gates the situation is similar. There are three invariant\n",
        "classes, the centre entangled with all corners, with all edges, and the\n",
        "edges paired in a ring.\n"
      ]
    },
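    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Counting parameters makes the saving concrete. Without symmetry, the\n",
        "crude count above was\n",
        "\n",
        "$$18 + \\binom{9}{2} = 18 + 36 = 54.$$\n",
        "\n",
        "With the twirled gates there are three single-qubit classes (centre,\n",
        "corners, edges), each with one $R_x$ and one $R_y$ angle, plus the three\n",
        "two-qubit classes, each with one $CR_y$ angle:\n",
        "\n",
        "$$3 \\times 2 + 3 = 9.$$\n"
      ]
    },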
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wzz4R-IxPtCH"
      },
      "source": [
        "The prediction of a label is obtained via a one-hot-encoding by\n",
        "measuring the expectation values of three invariant observables:\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UuFwqycTPtCH"
      },
      "source": [
        "$$O_{-}=Z_{\\text {middle }}=Z_{4}$$\n",
        "\n",
        "$$O_{\\circ}=\\frac{1}{4} \\sum_{i \\in \\text { corners }} Z_{i}=\\frac{1}{4}\\left[Z_{0}+Z_{2}+Z_{6}+Z_{8}\\right]$$\n",
        "\n",
        "$$O_{\\times}=\\frac{1}{4} \\sum_{i \\in \\text { edges }} Z_{i}=\\frac{1}{4}\\left[Z_{1}+Z_{3}+Z_{5}+Z_{7}\\right]$$\n",
        "\n",
        "$$\\hat{\\boldsymbol{y}}=\\left(\\left\\langle O_{\\circ}\\right\\rangle,\\left\\langle O_{-}\\right\\rangle,\\left\\langle O_{\\times}\\right\\rangle\\right)$$\n"
      ]
    },
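    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A short check of why these observables are invariant, assuming that\n",
        "each symmetry $s$ acts by permuting the qubits, so that\n",
        "$U(s) Z_i U(s)^{\\dagger}=Z_{s(i)}$: every symmetry of the square fixes\n",
        "the centre and maps corners to corners and edges to edges. For the\n",
        "corner observable, for example,\n",
        "\n",
        "$$U(s)\\, O_{\\circ}\\, U(s)^{\\dagger}=\\frac{1}{4} \\sum_{i \\in \\text { corners }} Z_{s(i)}=\\frac{1}{4} \\sum_{j \\in \\text { corners }} Z_{j}=O_{\\circ},$$\n",
        "\n",
        "since $s$ only reorders the four corner indices.\n"
      ]
    },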
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5STgZUP6PtCI"
      },
      "source": [
        "This is the quantum encoding of the symmetries into a learning problem.\n",
        "A prediction for a given data point will be obtained by selecting the\n",
        "class for which the observed expectation value is the largest.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WSH0J4_XPtCI"
      },
      "source": [
        "Now that we have a specific encoding and have decided on our observables\n",
        "we need to choose a suitable cost function to optimise. We will use an\n",
        "$l_2$ loss function acting on pairs of games and labels $D={(g,y)}$,\n",
        "where $D$ is our dataset.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BNyvZ0HRPtCI"
      },
      "source": [
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePThXcVlPtCI"
      },
      "source": [
        "Let\\'s now implement this!\n",
        "\n",
        "First let\\'s generate some games. Here we are creating a small program\n",
        "that will play Noughts and Crosses against itself in a random fashion.\n",
        "On completion, it spits out the winner and the winning board, with\n",
        "noughts as +1, draw as 0, and crosses as -1. There are 26,830 different\n",
        "possible games but we will only sample a few hundred.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "fBqEnXDYPtCI"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import random\n",
        "\n",
        "# Fix seeds for reproducability\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.manual_seed(16)\n",
        "random.seed(16)\n",
        "\n",
        "#  create an empty board\n",
        "def create_board():\n",
        "    return torch.tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])\n",
        "\n",
        "\n",
        "# Check for empty places on board\n",
        "def possibilities(board):\n",
        "    l = []\n",
        "    for i in range(len(board)):\n",
        "        for j in range(3):\n",
        "            if board[i, j] == 0:\n",
        "                l.append((i, j))\n",
        "    return l\n",
        "\n",
        "\n",
        "# Select a random place for the player\n",
        "def random_place(board, player):\n",
        "    selection = possibilities(board)\n",
        "    current_loc = random.choice(selection)\n",
        "    board[current_loc] = player\n",
        "    return board\n",
        "\n",
        "\n",
        "# Check if there is a winner by having 3 in a row\n",
        "def row_win(board, player):\n",
        "    for x in range(3):\n",
        "        lista = []\n",
        "        win = True\n",
        "\n",
        "        for y in range(3):\n",
        "            lista.append(board[x, y])\n",
        "\n",
        "            if board[x, y] != player:\n",
        "                win = False\n",
        "\n",
        "        if win:\n",
        "            break\n",
        "\n",
        "    return win\n",
        "\n",
        "\n",
        "# Check if there is a winner by having 3 in a column\n",
        "def col_win(board, player):\n",
        "    for x in range(3):\n",
        "        win = True\n",
        "\n",
        "        for y in range(3):\n",
        "            if board[y, x] != player:\n",
        "                win = False\n",
        "\n",
        "        if win:\n",
        "            break\n",
        "\n",
        "    return win\n",
        "\n",
        "\n",
        "# Check if there is a winner by having 3 along a diagonal\n",
        "def diag_win(board, player):\n",
        "    win1 = True\n",
        "    win2 = True\n",
        "    for x, y in [(0, 0), (1, 1), (2, 2)]:\n",
        "        if board[x, y] != player:\n",
        "            win1 = False\n",
        "\n",
        "    for x, y in [(0, 2), (1, 1), (2, 0)]:\n",
        "        if board[x, y] != player:\n",
        "            win2 = False\n",
        "\n",
        "    return win1 or win2\n",
        "\n",
        "\n",
        "# Check if the win conditions have been met or if a draw has occurred\n",
        "def evaluate_game(board):\n",
        "    winner = None\n",
        "    for player in [1, -1]:\n",
        "        if row_win(board, player) or col_win(board, player) or diag_win(board, player):\n",
        "            winner = player\n",
        "\n",
        "    if torch.all(board != 0) and winner == None:\n",
        "        winner = 0\n",
        "\n",
        "    return winner\n",
        "\n",
        "\n",
        "# Main function to start the game\n",
        "def play_game():\n",
        "    board, winner, counter = create_board(), None, 1\n",
        "    while winner == None:\n",
        "        for player in [1, -1]:\n",
        "            board = random_place(board, player)\n",
        "            counter += 1\n",
        "            winner = evaluate_game(board)\n",
        "            if winner != None:\n",
        "                break\n",
        "\n",
        "    return [board.flatten(), winner]\n",
        "\n",
        "\n",
        "def create_dataset(size_for_each_winner):\n",
        "    game_d = {-1: [], 0: [], 1: []}\n",
        "\n",
        "    while min([len(v) for k, v in game_d.items()]) < size_for_each_winner:\n",
        "        board, winner = play_game()\n",
        "        if len(game_d[winner]) < size_for_each_winner:\n",
        "            game_d[winner].append(board)\n",
        "\n",
        "    res = []\n",
        "    for winner, boards in game_d.items():\n",
        "        res += [(board, winner) for board in boards]\n",
        "\n",
        "    return res\n",
        "\n",
        "\n",
        "NUM_TRAINING = 450\n",
        "NUM_VALIDATION = 600\n",
        "\n",
        "# Create datasets but with even numbers of each outcome\n",
        "with torch.no_grad():\n",
        "    dataset = create_dataset(NUM_TRAINING // 3)\n",
        "    dataset_val = create_dataset(NUM_VALIDATION // 3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xB-FbtShPtCJ"
      },
      "source": [
        "Now let\\'s create the relevant circuit expectation values that respect\n",
        "the symmetry classes we defined over the single-site and two-site\n",
        "measurements.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 408
        },
        "id": "B3_k2h_IPtCJ",
        "outputId": "49afcd58-5e30-48fa-b6ec-de2d62a41d03"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 3700x1000 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "import pennylane as qml\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Set up a nine-qubit system\n",
        "dev = qml.device(\"default.qubit.torch\", wires=9)\n",
        "\n",
        "ob_center = qml.PauliZ(4)\n",
        "ob_corner = (qml.PauliZ(0) + qml.PauliZ(2) + qml.PauliZ(6) + qml.PauliZ(8)) * (1 / 4)\n",
        "ob_edge = (qml.PauliZ(1) + qml.PauliZ(3) + qml.PauliZ(5) + qml.PauliZ(7)) * (1 / 4)\n",
        "\n",
        "# Now let's encode the data in the following qubit models, first with symmetry\n",
        "@qml.qnode(dev, interface=\"torch\")\n",
        "def circuit(x, p):\n",
        "    qml.Hadamard(wires=0)\n",
        "    qml.Hadamard(wires=1)\n",
        "    qml.Hadamard(wires=2)\n",
        "    qml.Hadamard(wires=3)\n",
        "    qml.Hadamard(wires=4)\n",
        "    qml.Hadamard(wires=5)\n",
        "    qml.Hadamard(wires=6)\n",
        "    qml.Hadamard(wires=7)\n",
        "    qml.Hadamard(wires=8)\n",
        "    qml.RY(x[0], wires=0)\n",
        "    qml.RY(x[1], wires=1)\n",
        "    qml.RY(x[2], wires=2)\n",
        "    qml.RY(x[3], wires=3)\n",
        "    qml.RY(x[4], wires=4)\n",
        "    qml.RY(x[5], wires=5)\n",
        "    qml.RY(x[6], wires=6)\n",
        "    qml.RY(x[7], wires=7)\n",
        "    qml.RY(x[8], wires=8)\n",
        "\n",
        "    # Centre single-qubit rotation\n",
        "    qml.RX(p[0], wires=4)\n",
        "    qml.RY(p[1], wires=4)\n",
        "\n",
        "    # Corner single-qubit rotation\n",
        "    qml.RX(p[2], wires=0)\n",
        "    qml.RX(p[2], wires=2)\n",
        "    qml.RX(p[2], wires=6)\n",
        "    qml.RX(p[2], wires=8)\n",
        "\n",
        "    qml.RY(p[3], wires=0)\n",
        "    qml.RY(p[3], wires=2)\n",
        "    qml.RY(p[3], wires=6)\n",
        "    qml.RY(p[3], wires=8)\n",
        "\n",
        "    # Edge single-qubit rotation\n",
        "    qml.RX(p[4], wires=1)\n",
        "    qml.RX(p[4], wires=3)\n",
        "    qml.RX(p[4], wires=5)\n",
        "    qml.RX(p[4], wires=7)\n",
        "\n",
        "    qml.RY(p[5], wires=1)\n",
        "    qml.RY(p[5], wires=3)\n",
        "    qml.RY(p[5], wires=5)\n",
        "    qml.RY(p[5], wires=7)\n",
        "\n",
        "    # Double Entagling two-qubit gates\n",
        "    # circling the edge of the board\n",
        "    qml.CRY(p[6], wires=[0, 1])\n",
        "    qml.CRY(p[6], wires=[2, 1])\n",
        "    qml.CRY(p[6], wires=[2, 5])\n",
        "    qml.CRY(p[6], wires=[8, 5])\n",
        "    qml.CRY(p[6], wires=[8, 7])\n",
        "    qml.CRY(p[6], wires=[6, 7])\n",
        "    qml.CRY(p[6], wires=[6, 3])\n",
        "    qml.CRY(p[6], wires=[0, 3])\n",
        "\n",
        "    # To the corners from the centre\n",
        "    qml.CRY(p[7], wires=[4, 0])\n",
        "    qml.CRY(p[7], wires=[4, 2])\n",
        "    qml.CRY(p[7], wires=[4, 6])\n",
        "    qml.CRY(p[7], wires=[4, 8])\n",
        "\n",
        "    # To the centre from the edges\n",
        "    qml.CRY(p[8], wires=[1, 4])\n",
        "    qml.CRY(p[8], wires=[3, 4])\n",
        "    qml.CRY(p[8], wires=[5, 4])\n",
        "    qml.CRY(p[8], wires=[7, 4])\n",
        "\n",
        "    # circling the edge of the board\n",
        "    qml.CRY(p[6], wires=[0, 1])\n",
        "    qml.CRY(p[6], wires=[2, 1])\n",
        "    qml.CRY(p[6], wires=[2, 5])\n",
        "    qml.CRY(p[6], wires=[8, 5])\n",
        "    qml.CRY(p[6], wires=[8, 7])\n",
        "    qml.CRY(p[6], wires=[6, 7])\n",
        "    qml.CRY(p[6], wires=[6, 3])\n",
        "    qml.CRY(p[6], wires=[0, 3])\n",
        "\n",
        "    # To the corners from the centre\n",
        "    qml.CRY(p[7], wires=[4, 0])\n",
        "    qml.CRY(p[7], wires=[4, 2])\n",
        "    qml.CRY(p[7], wires=[4, 6])\n",
        "    qml.CRY(p[7], wires=[4, 8])\n",
        "\n",
        "    # To the centre from the edges\n",
        "    qml.CRY(p[8], wires=[1, 4])\n",
        "    qml.CRY(p[8], wires=[3, 4])\n",
        "    qml.CRY(p[8], wires=[5, 4])\n",
        "    qml.CRY(p[8], wires=[7, 4])\n",
        "\n",
        "    return [qml.expval(ob_center), qml.expval(ob_corner), qml.expval(ob_edge)]\n",
        "\n",
        "\n",
        "fig, ax = qml.draw_mpl(circuit)([0] * 9, 18 * [0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZKg8hC1PtCJ"
      },
      "source": [
        "Let\\'s also look at the same series of gates but this time they are\n",
        "applied independently from one another, so we won\\'t be preserving the\n",
        "symmetries with our gate operations. Practically this also means more\n",
        "parameters, as previously groups of gates were updated together.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 408
        },
        "id": "Bhp9RWhXPtCJ",
        "outputId": "24413c39-e6ee-414b-e7c0-517743512713"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 3700x1000 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "@qml.qnode(dev, interface=\"torch\")\n",
        "def circuit_no_sym(x, p):\n",
        "\n",
        "    qml.Hadamard(wires=0)\n",
        "    qml.Hadamard(wires=1)\n",
        "    qml.Hadamard(wires=2)\n",
        "    qml.Hadamard(wires=3)\n",
        "    qml.Hadamard(wires=4)\n",
        "    qml.Hadamard(wires=5)\n",
        "    qml.Hadamard(wires=6)\n",
        "    qml.Hadamard(wires=7)\n",
        "    qml.Hadamard(wires=8)\n",
        "    qml.RY(x[0], wires=0)\n",
        "    qml.RY(x[1], wires=1)\n",
        "    qml.RY(x[2], wires=2)\n",
        "    qml.RY(x[3], wires=3)\n",
        "    qml.RY(x[4], wires=4)\n",
        "    qml.RY(x[5], wires=5)\n",
        "    qml.RY(x[6], wires=6)\n",
        "    qml.RY(x[7], wires=7)\n",
        "    qml.RY(x[8], wires=8)\n",
        "\n",
        "    # Centre single-qubit rotation\n",
        "    qml.RX(p[0], wires=4)\n",
        "    qml.RY(p[1], wires=4)\n",
        "\n",
        "    # Note in this circuit the parameters aren't all the same.\n",
        "    # Previously they were identical to ensure they were applied\n",
        "    # as one combined gate. The fact they can all vary independently\n",
        "    # here means we aren't respecting the symmetry.\n",
        "\n",
        "    # Corner single-qubit rotation\n",
        "    qml.RX(p[2], wires=0)\n",
        "    qml.RX(p[3], wires=2)\n",
        "    qml.RX(p[4], wires=6)\n",
        "    qml.RX(p[5], wires=8)\n",
        "\n",
        "    qml.RY(p[6], wires=0)\n",
        "    qml.RY(p[7], wires=2)\n",
        "    qml.RY(p[8], wires=6)\n",
        "    qml.RY(p[9], wires=8)\n",
        "\n",
        "    # Edge single-qubit rotation\n",
        "    qml.RX(p[10], wires=1)\n",
        "    qml.RX(p[11], wires=3)\n",
        "    qml.RX(p[12], wires=5)\n",
        "    qml.RX(p[13], wires=7)\n",
        "\n",
        "    qml.RY(p[14], wires=1)\n",
        "    qml.RY(p[15], wires=3)\n",
        "    qml.RY(p[16], wires=5)\n",
        "    qml.RY(p[17], wires=7)\n",
        "\n",
        "    # Double Entagling two-qubit gates\n",
        "    # circling the edge of the board\n",
        "    qml.CRY(p[18], wires=[0, 1])\n",
        "    qml.CRY(p[19], wires=[2, 1])\n",
        "    qml.CRY(p[20], wires=[2, 5])\n",
        "    qml.CRY(p[21], wires=[8, 5])\n",
        "    qml.CRY(p[22], wires=[8, 7])\n",
        "    qml.CRY(p[23], wires=[6, 7])\n",
        "    qml.CRY(p[24], wires=[6, 3])\n",
        "    qml.CRY(p[25], wires=[0, 3])\n",
        "\n",
        "    # To the corners from the centre\n",
        "    qml.CRY(p[26], wires=[4, 0])\n",
        "    qml.CRY(p[27], wires=[4, 2])\n",
        "    qml.CRY(p[28], wires=[4, 6])\n",
        "    qml.CRY(p[29], wires=[4, 8])\n",
        "\n",
        "    # To the centre from the edges\n",
        "    qml.CRY(p[30], wires=[1, 4])\n",
        "    qml.CRY(p[31], wires=[3, 4])\n",
        "    qml.CRY(p[32], wires=[5, 4])\n",
        "    qml.CRY(p[33], wires=[7, 4])\n",
        "\n",
        "    # circling the edge of the board\n",
        "    qml.CRY(p[18], wires=[0, 1])\n",
        "    qml.CRY(p[19], wires=[2, 1])\n",
        "    qml.CRY(p[20], wires=[2, 5])\n",
        "    qml.CRY(p[21], wires=[8, 5])\n",
        "    qml.CRY(p[22], wires=[8, 7])\n",
        "    qml.CRY(p[23], wires=[6, 7])\n",
        "    qml.CRY(p[24], wires=[6, 3])\n",
        "    qml.CRY(p[25], wires=[0, 3])\n",
        "\n",
        "    # To the corners from the centre\n",
        "    qml.CRY(p[26], wires=[4, 0])\n",
        "    qml.CRY(p[27], wires=[4, 2])\n",
        "    qml.CRY(p[28], wires=[4, 6])\n",
        "    qml.CRY(p[29], wires=[4, 8])\n",
        "\n",
        "    # To the centre from the edges\n",
        "    qml.CRY(p[30], wires=[1, 4])\n",
        "    qml.CRY(p[31], wires=[3, 4])\n",
        "    qml.CRY(p[32], wires=[5, 4])\n",
        "    qml.CRY(p[33], wires=[7, 4])\n",
        "\n",
        "    return [qml.expval(ob_center), qml.expval(ob_corner), qml.expval(ob_edge)]\n",
        "\n",
        "\n",
        "fig, ax = qml.draw_mpl(circuit_no_sym)([0] * 9, [0] * 34)"
      ]
    },
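    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check on the parameter counts, the indices used in the two\n",
        "circuits can be tallied by gate group (centre rotations, corner rotations,\n",
        "edge rotations, outer-ring entanglers, centre-to-corner entanglers and\n",
        "edge-to-centre entanglers). This is only a tally of the `p[...]` indices\n",
        "that appear in the code above, not an additional result:\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "\\text{symmetric:} \\quad & 2 + 2 + 2 + 1 + 1 + 1 = 9, \\\\\n",
        "\\text{independent:} \\quad & 2 + 8 + 8 + 8 + 4 + 4 = 34.\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "In both circuits the second entangling layer reuses the parameters of the\n",
        "first, so it adds no new parameters.\n"
      ]
    },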
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1REWjSNvPtCJ"
      },
      "source": [
        "Note again how, though these circuits have a similar form to before,\n",
        "they are parameterised differently. We need to feed the vector\n",
        "$\\boldsymbol{y}$ made up of the expectation value of these three\n",
        "operators into the loss function and use this to update our parameters.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "7wTu3Uw_PtCJ"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def encode_game(game):\n",
        "    board, res = game\n",
        "    x = board * (2 * math.pi) / 3\n",
        "    if res == 1:\n",
        "        y = [-1, -1, 1]\n",
        "    elif res == -1:\n",
        "        y = [1, -1, -1]\n",
        "    else:\n",
        "        y = [-1, 1, -1]\n",
        "    return x, y"
      ]
    },
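    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small worked example of this encoding, suppose that each board entry\n",
        "$g_i$ takes a value in $\\{-1, 0, 1\\}$. This is an assumption here, since the\n",
        "board representation is defined earlier in the notebook. The angle passed\n",
        "to `qml.RY` on qubit $i$ is then\n",
        "\n",
        "$$\n",
        "x_i = \\frac{2\\pi}{3}\\, g_i \\in \\left\\{-\\frac{2\\pi}{3},\\; 0,\\; \\frac{2\\pi}{3}\\right\\},\n",
        "$$\n",
        "\n",
        "so under that assumption the three possible entries map to three angles\n",
        "spaced $2\\pi/3$ apart.\n"
      ]
    },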
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qMYUsXoZPtCJ"
      },
      "source": [
        "Recall that the loss function we\\'re interested in is\n",
        "$\\mathcal{L}(\\mathcal{D})=\\frac{1}{|\\mathcal{D}|} \\sum_{(\\boldsymbol{g}, \\boldsymbol{y}) \\in \\mathcal{D}}\\|\\hat{\\boldsymbol{y}}(\\boldsymbol{g})-\\boldsymbol{y}\\|_{2}^{2}$.\n",
        "We need to define this and then we can begin our optimisation.\n"
      ]
    },
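    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for the scale of this loss, here is a worked example for a\n",
        "single game with target $\\boldsymbol{y} = (-1, -1, 1)$. It assumes that each\n",
        "entry of $\\hat{\\boldsymbol{y}}$ lies in $[-1, 1]$, which depends on the\n",
        "observables defined earlier in the notebook. An uninformative prediction\n",
        "$\\hat{\\boldsymbol{y}} = (0, 0, 0)$ gives\n",
        "\n",
        "$$\n",
        "\\|\\hat{\\boldsymbol{y}}-\\boldsymbol{y}\\|_{2}^{2} = 1^2 + 1^2 + (-1)^2 = 3,\n",
        "$$\n",
        "\n",
        "while a perfect prediction gives $0$ and the worst case\n",
        "$\\hat{\\boldsymbol{y}} = -\\boldsymbol{y}$ gives $2^2 + 2^2 + (-2)^2 = 12$.\n"
      ]
    },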
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "_RVbTFtXPtCJ"
      },
      "outputs": [],
      "source": [
        "# calculate the mean square error for this classification problem\n",
        "def cost_function(params, input, target):\n",
        "    output = torch.stack([torch.hstack(circuit(x, params)) for x in input])\n",
        "    vec = output - target\n",
        "    sum_sqr = torch.sum(vec * vec, dim=1)\n",
        "    return torch.mean(sum_sqr)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wiKzIjHSPtCJ"
      },
      "source": [
        "Let\\'s now train our symmetry-preserving circuit on the data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "ISDNjQj0PtCK",
        "outputId": "5e938cbf-2218-46a6-8edb-8ef841c22141"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "accuracy without training = 0.5249999761581421\n",
            "Epoch:  1 | Loss: 3.176805 | Validation accuracy: 0.485000\n",
            "Epoch:  2 | Loss: 2.856527 | Validation accuracy: 0.436667\n",
            "Epoch:  3 | Loss: 2.822027 | Validation accuracy: 0.496667\n",
            "Epoch:  4 | Loss: 2.764042 | Validation accuracy: 0.456667\n",
            "Epoch:  5 | Loss: 2.671184 | Validation accuracy: 0.560000\n",
            "Epoch:  6 | Loss: 2.681827 | Validation accuracy: 0.545000\n",
            "Epoch:  7 | Loss: 2.649758 | Validation accuracy: 0.541667\n",
            "Epoch:  8 | Loss: 2.564894 | Validation accuracy: 0.573333\n",
            "Epoch:  9 | Loss: 2.649880 | Validation accuracy: 0.591667\n",
            "Epoch: 10 | Loss: 2.615056 | Validation accuracy: 0.580000\n",
            "Epoch: 11 | Loss: 2.590838 | Validation accuracy: 0.596667\n",
            "Epoch: 12 | Loss: 2.577285 | Validation accuracy: 0.583333\n",
            "Epoch: 13 | Loss: 2.638542 | Validation accuracy: 0.581667\n",
            "Epoch: 14 | Loss: 2.569011 | Validation accuracy: 0.595000\n",
            "Epoch: 15 | Loss: 2.578797 | Validation accuracy: 0.585000\n"
          ]
        }
      ],
      "source": [
        "from torch import optim\n",
        "import numpy as np\n",
        "\n",
        "params = 0.01 * torch.randn(9)\n",
        "params.requires_grad = True\n",
        "opt = optim.Adam([params], lr=1e-2)\n",
        "\n",
        "\n",
        "max_epoch = 15\n",
        "max_step = 30\n",
        "batch_size = 10\n",
        "\n",
        "encoded_dataset = list(zip(*[encode_game(game) for game in dataset]))\n",
        "encoded_dataset_val = list(zip(*[encode_game(game) for game in dataset_val]))\n",
        "\n",
        "\n",
        "def accuracy(p, x_val, y_val):\n",
        "    with torch.no_grad():\n",
        "        y_val = torch.tensor(y_val)\n",
        "        y_out = torch.stack([torch.hstack(circuit(x, p)) for x in x_val])\n",
        "        acc = torch.sum(torch.argmax(y_out, axis=1) == torch.argmax(y_val, axis=1))\n",
        "        return acc / len(x_val)\n",
        "\n",
        "\n",
        "print(f\"accuracy without training = {accuracy(params, *encoded_dataset_val)}\")\n",
        "\n",
        "x_dataset = torch.stack(encoded_dataset[0])\n",
        "y_dataset = torch.tensor(encoded_dataset[1], requires_grad=False)\n",
        "\n",
        "saved_costs_sym = []\n",
        "saved_accs_sym = []\n",
        "for epoch in range(max_epoch):\n",
        "    rand_idx = torch.randperm(len(x_dataset))\n",
        "    # Shuffled dataset\n",
        "    x_dataset = x_dataset[rand_idx]\n",
        "    y_dataset = y_dataset[rand_idx]\n",
        "\n",
        "    costs = []\n",
        "\n",
        "    for step in range(max_step):\n",
        "        x_batch = x_dataset[step * batch_size : (step + 1) * batch_size]\n",
        "        y_batch = y_dataset[step * batch_size : (step + 1) * batch_size]\n",
        "\n",
        "        def opt_func():\n",
        "            opt.zero_grad()\n",
        "            loss = cost_function(params, x_batch, y_batch)\n",
        "            costs.append(loss.item())\n",
        "            loss.backward()\n",
        "            return loss\n",
        "\n",
        "        opt.step(opt_func)\n",
        "\n",
        "    cost = np.mean(costs)\n",
        "    saved_costs_sym.append(cost)\n",
        "\n",
        "    if (epoch + 1) % 1 == 0:\n",
        "        # Compute validation accuracy\n",
        "        acc_val = accuracy(params, *encoded_dataset_val)\n",
        "        saved_accs_sym.append(acc_val)\n",
        "\n",
        "        res = [epoch + 1, cost, acc_val]\n",
        "        print(\"Epoch: {:2d} | Loss: {:3f} | Validation accuracy: {:3f}\".format(*res))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z9SOd6lQPtCK"
      },
      "source": [
        "Now we train the non-symmetry preserving circuit.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "THRXOBEUPtCK",
        "outputId": "3e4761c0-bdf8-470e-b9d3-eba258c5aac0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "accuracy without training = 0.44999998807907104\n",
            "Epoch:  1 | Loss: 3.169856 | Validation accuracy: 0.438333\n",
            "Epoch:  2 | Loss: 2.915814 | Validation accuracy: 0.443333\n",
            "Epoch:  3 | Loss: 2.848999 | Validation accuracy: 0.445000\n",
            "Epoch:  4 | Loss: 2.798156 | Validation accuracy: 0.428333\n",
            "Epoch:  5 | Loss: 2.745969 | Validation accuracy: 0.453333\n",
            "Epoch:  6 | Loss: 2.714810 | Validation accuracy: 0.461667\n",
            "Epoch:  7 | Loss: 2.697196 | Validation accuracy: 0.485000\n",
            "Epoch:  8 | Loss: 2.678972 | Validation accuracy: 0.480000\n",
            "Epoch:  9 | Loss: 2.667330 | Validation accuracy: 0.510000\n",
            "Epoch: 10 | Loss: 2.651770 | Validation accuracy: 0.513333\n",
            "Epoch: 11 | Loss: 2.636838 | Validation accuracy: 0.525000\n",
            "Epoch: 12 | Loss: 2.623924 | Validation accuracy: 0.560000\n",
            "Epoch: 13 | Loss: 2.608304 | Validation accuracy: 0.531667\n",
            "Epoch: 14 | Loss: 2.597640 | Validation accuracy: 0.555000\n",
            "Epoch: 15 | Loss: 2.588152 | Validation accuracy: 0.565000\n"
          ]
        }
      ],
      "source": [
        "params = 0.01 * torch.randn(34)\n",
        "params.requires_grad = True\n",
        "opt = optim.Adam([params], lr=1e-2)\n",
        "\n",
        "# calculate mean square error for this classification problem\n",
        "\n",
        "\n",
        "def cost_function_no_sym(params, input, target):\n",
        "    output = torch.stack([torch.hstack(circuit_no_sym(x, params)) for x in input])\n",
        "    vec = output - target\n",
        "    sum_sqr = torch.sum(vec * vec, dim=1)\n",
        "    return torch.mean(sum_sqr)\n",
        "\n",
        "\n",
        "max_epoch = 15\n",
        "max_step = 30\n",
        "batch_size = 15\n",
        "\n",
        "encoded_dataset = list(zip(*[encode_game(game) for game in dataset]))\n",
        "encoded_dataset_val = list(zip(*[encode_game(game) for game in dataset_val]))\n",
        "\n",
        "\n",
        "def accuracy_no_sym(p, x_val, y_val):\n",
        "    with torch.no_grad():\n",
        "        y_val = torch.tensor(y_val)\n",
        "        y_out = torch.stack([torch.hstack(circuit_no_sym(x, p)) for x in x_val])\n",
        "        acc = torch.sum(torch.argmax(y_out, axis=1) == torch.argmax(y_val, axis=1))\n",
        "        return acc / len(x_val)\n",
        "\n",
        "\n",
        "print(f\"accuracy without training = {accuracy_no_sym(params, *encoded_dataset_val)}\")\n",
        "\n",
        "\n",
        "x_dataset = torch.stack(encoded_dataset[0])\n",
        "y_dataset = torch.tensor(encoded_dataset[1], requires_grad=False)\n",
        "\n",
        "saved_costs = []\n",
        "saved_accs = []\n",
        "for epoch in range(max_epoch):\n",
        "    rand_idx = torch.randperm(len(x_dataset))\n",
        "    # Shuffled dataset\n",
        "    x_dataset = x_dataset[rand_idx]\n",
        "    y_dataset = y_dataset[rand_idx]\n",
        "\n",
        "    costs = []\n",
        "\n",
        "    for step in range(max_step):\n",
        "        x_batch = x_dataset[step * batch_size : (step + 1) * batch_size]\n",
        "        y_batch = y_dataset[step * batch_size : (step + 1) * batch_size]\n",
        "\n",
        "        def opt_func():\n",
        "            opt.zero_grad()\n",
        "            loss = cost_function_no_sym(params, x_batch, y_batch)\n",
        "            costs.append(loss.item())\n",
        "            loss.backward()\n",
        "            return loss\n",
        "\n",
        "        opt.step(opt_func)\n",
        "\n",
        "    cost = np.mean(costs)\n",
        "    saved_costs.append(costs)\n",
        "\n",
        "    if (epoch + 1) % 1 == 0:\n",
        "        # Compute validation accuracy\n",
        "        acc_val = accuracy_no_sym(params, *encoded_dataset_val)\n",
        "        saved_accs.append(acc_val)\n",
        "\n",
        "        res = [epoch + 1, cost, acc_val]\n",
        "        print(\"Epoch: {:2d} | Loss: {:3f} | Validation accuracy: {:3f}\".format(*res))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ujgVimaPtCK"
      },
      "source": [
        "Finally let\\'s plot the results and see how the two training regimes\n",
        "differ.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "GeF6cgOLPtCK",
        "outputId": "9838d950-89b7-44d9-e2ec-f1a6b4ffcd72"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "plt.title(\"Validation accuracies\")\n",
        "plt.plot(saved_accs_sym, \"b\", label=\"Symmetric\")\n",
        "plt.plot(saved_accs, \"g\", label=\"Standard\")\n",
        "\n",
        "plt.ylabel(\"Validation accuracy (%)\")\n",
        "plt.xlabel(\"Optimization steps\")\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iilNBvUJPtCK"
      },
      "source": [
        "What we can see then is that by paying attention to the symmetries\n",
        "intrinsic to the learning problem and reflecting this in an equivariant\n",
        "gate set we have managed to improve our learning accuracies, while also\n",
        "using fewer parameters. While the symmetry-aware circuit clearly\n",
        "outperforms the naive one, it is notable however that the learning\n",
        "accuracies in both cases are hardly ideal given this is a solved game.\n",
        "So paying attention to symmetries definitely helps, but it also isn\\'t a\n",
        "magic bullet!\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "59M09yH_PtCK"
      },
      "source": [
        "The use of symmetries in both quantum and classical machine learning is\n",
        "a developing field, so we can expect new results to emerge over the\n",
        "coming years. If you want to get involved, the references given below\n",
        "are a great place to start.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AfMCcW_QPtCK"
      },
      "source": [
        "References\n",
        "==========\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Xne5NFbPtCK"
      },
      "source": [
        "Acknowledgments\n",
        "===============\n",
        "\n",
        "The author would also like to acknowledge the helpful input of C.-Y.\n",
        "Park.\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.17"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}