[404218]: / Code / Tensor Network vs FC Explainability / Dataset 1 / DS1 2FC 2TN TPU kkawchak.ipynb

Download this file

1266 lines (1266 with data), 217.9 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "V28"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8XnVMPBXmtRa"
      },
      "source": [
        "# TensorNetworks in Neural Networks.\n",
        "\n",
        "Here, we have a small toy example of how to use a TN inside of a fully connected neural network.\n",
        "\n",
        "First off, let's install tensornetwork"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7HGRsYNAFxME"
      },
      "source": [
        "# !pip install tensornetwork\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "# Import tensornetwork\n",
        "import tensornetwork as tn\n",
        "import random\n",
        "import time\n",
        "import pandas as pd\n",
        "# Set the backend to tesorflow\n",
        "# (default is numpy)\n",
        "tn.set_default_backend(\"tensorflow\")\n",
        "np.random.seed(42)\n",
        "random.seed(42)\n",
        "tf.random.set_seed(42)\n",
        "# Explainability code assistance aided by ChatGPT3.5\n",
        "# 2021 Kelly, D. TensorFlow Explainable AI tutorial https://www.youtube.com/watch?v=6xePkn3-LME"
      ],
      "execution_count": 90,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1OMCo5XmrYu"
      },
      "source": [
        "# TensorNetwork layer definition\n",
        "\n",
        "Here, we define the TensorNetwork layer we wish to use to replace the fully connected layer. Here, we simply use a 2 node Matrix Product Operator network to replace the normal dense weight matrix.\n",
        "\n",
        "We TensorNetwork's NCon API to keep the code short."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wvSMKtPufnLp"
      },
      "source": [
        "class TNLayer(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self):\n",
        "    super(TNLayer, self).__init__()\n",
        "    # Create the variables for the layer.\n",
        "    self.a_var = tf.Variable(tf.random.normal(shape=(32, 32, 2),\n",
        "                                              stddev=1.0/32.0),\n",
        "                             name=\"a\", trainable=True)\n",
        "    self.b_var = tf.Variable(tf.random.normal(shape=(32, 32, 2),\n",
        "                                              stddev=1.0/32.0),\n",
        "                             name=\"b\", trainable=True)\n",
        "    self.bias = tf.Variable(tf.zeros(shape=(32, 32)),\n",
        "                            name=\"bias\", trainable=True)\n",
        "\n",
        "  def call(self, inputs):\n",
        "    # Define the contraction.\n",
        "    # We break it out so we can parallelize a batch using\n",
        "    # tf.vectorized_map (see below).\n",
        "    def f(input_vec, a_var, b_var, bias_var):\n",
        "      # Reshape to a matrix instead of a vector.\n",
        "      input_vec = tf.reshape(input_vec, (32, 32))\n",
        "\n",
        "      # Now we create the network.\n",
        "      a = tn.Node(a_var)\n",
        "      b = tn.Node(b_var)\n",
        "      x_node = tn.Node(input_vec)\n",
        "      a[1] ^ x_node[0]\n",
        "      b[1] ^ x_node[1]\n",
        "      a[2] ^ b[2]\n",
        "\n",
        "      # The TN should now look like this\n",
        "      #   |     |\n",
        "      #   a --- b\n",
        "      #    \\   /\n",
        "      #      x\n",
        "\n",
        "      # Now we begin the contraction.\n",
        "      c = a @ x_node\n",
        "      result = (c @ b).tensor\n",
        "\n",
        "      # To make the code shorter, we also could've used Ncon.\n",
        "      # The above few lines of code is the same as this:\n",
        "      # result = tn.ncon([x, a_var, b_var], [[1, 2], [-1, 1, 3], [-2, 2, 3]])\n",
        "\n",
        "      # Finally, add bias.\n",
        "      return result + bias_var\n",
        "\n",
        "    # To deal with a batch of items, we can use the tf.vectorized_map\n",
        "    # function.\n",
        "    # https://www.tensorflow.org/api_docs/python/tf/vectorized_map\n",
        "    result = tf.vectorized_map(\n",
        "        lambda vec: f(vec, self.a_var, self.b_var, self.bias), inputs)\n",
        "    return tf.nn.relu(tf.reshape(result, (-1, 1024)))"
      ],
      "execution_count": 91,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V-CVqIhPnhY_"
      },
      "source": [
        "# Smaller model\n",
        "These two models are effectively the same, but notice how the TN layer has nearly 10x fewer parameters."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bbKsmK8wIFTp",
        "outputId": "4146e343-d5ff-4f17-faf5-2a8bc7b74c8c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        }
      },
      "source": [
        "Dense = tf.keras.layers.Dense\n",
        "tn_model = tf.keras.Sequential(\n",
        "    [\n",
        "     tf.keras.Input(shape=(2,)),\n",
        "     Dense(1024, activation=tf.nn.relu),\n",
        "     # Start Modified Layers\n",
        "     Dense(1024, activation=tf.nn.relu),\n",
        "     Dense(1024, activation=tf.nn.relu),\n",
        "     TNLayer(),\n",
        "     TNLayer(),\n",
        "     # Finish Modified Layers\n",
        "     Dense(1, activation=None)])\n",
        "tn_model.summary()"
      ],
      "execution_count": 92,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"sequential_8\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " dense_30 (Dense)            (None, 1024)              3072      \n",
            "                                                                 \n",
            " dense_31 (Dense)            (None, 1024)              1049600   \n",
            "                                                                 \n",
            " dense_32 (Dense)            (None, 1024)              1049600   \n",
            "                                                                 \n",
            " tn_layer_9 (TNLayer)        (None, 1024)              5120      \n",
            "                                                                 \n",
            " tn_layer_10 (TNLayer)       (None, 1024)              5120      \n",
            "                                                                 \n",
            " dense_33 (Dense)            (None, 1)                 1025      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 2113537 (8.06 MB)\n",
            "Trainable params: 2113537 (8.06 MB)\n",
            "Non-trainable params: 0 (0.00 Byte)\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWwoYp0WnsLA"
      },
      "source": [
        "# Training a model\n",
        "\n",
        "You can train the TN model just as you would a normal neural network model! Here, we give an example of how to do it in Keras."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qDFzOC7sDBJ-"
      },
      "source": [
        "X = np.concatenate([np.random.randn(120, 2) + np.array([3, 3]),\n",
        "                    np.random.randn(120, 2) + np.array([-3, -3]),\n",
        "                    np.random.randn(120, 2) + np.array([-3, 3]),\n",
        "                    np.random.randn(120, 2) + np.array([3, -3])])\n",
        "\n",
        "Y = np.concatenate([np.ones((240)), -np.ones((240))])"
      ],
      "execution_count": 93,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "seconds = time.time()\n",
        "print(\"Time in seconds since beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "19TWP-1eKURB",
        "outputId": "676af897-2ef8-45bf-8e5f-d1338ef2fe9b"
      },
      "execution_count": 94,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since beginning of run: 1712560083.9302807\n",
            "Mon Apr  8 07:08:03 2024\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "crc0q1vbIyTj",
        "outputId": "7a7a4363-88c0-44c0-e572-fd6af9f22775",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        }
      },
      "source": [
        "tn_model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n",
        "tn_model.fit(X, Y, epochs=300, verbose=2)"
      ],
      "execution_count": 95,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/300\n",
            "15/15 - 2s - loss: 0.9375 - 2s/epoch - 103ms/step\n",
            "Epoch 2/300\n",
            "15/15 - 0s - loss: 0.2116 - 174ms/epoch - 12ms/step\n",
            "Epoch 3/300\n",
            "15/15 - 0s - loss: 0.0701 - 171ms/epoch - 11ms/step\n",
            "Epoch 4/300\n",
            "15/15 - 0s - loss: 0.0407 - 166ms/epoch - 11ms/step\n",
            "Epoch 5/300\n",
            "15/15 - 0s - loss: 0.0273 - 162ms/epoch - 11ms/step\n",
            "Epoch 6/300\n",
            "15/15 - 0s - loss: 0.0219 - 163ms/epoch - 11ms/step\n",
            "Epoch 7/300\n",
            "15/15 - 0s - loss: 0.0176 - 170ms/epoch - 11ms/step\n",
            "Epoch 8/300\n",
            "15/15 - 0s - loss: 0.0134 - 164ms/epoch - 11ms/step\n",
            "Epoch 9/300\n",
            "15/15 - 0s - loss: 0.0097 - 161ms/epoch - 11ms/step\n",
            "Epoch 10/300\n",
            "15/15 - 0s - loss: 0.0070 - 157ms/epoch - 10ms/step\n",
            "Epoch 11/300\n",
            "15/15 - 0s - loss: 0.0078 - 157ms/epoch - 10ms/step\n",
            "Epoch 12/300\n",
            "15/15 - 0s - loss: 0.0031 - 158ms/epoch - 11ms/step\n",
            "Epoch 13/300\n",
            "15/15 - 0s - loss: 0.0022 - 158ms/epoch - 11ms/step\n",
            "Epoch 14/300\n",
            "15/15 - 0s - loss: 0.0010 - 161ms/epoch - 11ms/step\n",
            "Epoch 15/300\n",
            "15/15 - 0s - loss: 3.7597e-04 - 160ms/epoch - 11ms/step\n",
            "Epoch 16/300\n",
            "15/15 - 0s - loss: 2.3599e-04 - 162ms/epoch - 11ms/step\n",
            "Epoch 17/300\n",
            "15/15 - 0s - loss: 1.7035e-04 - 164ms/epoch - 11ms/step\n",
            "Epoch 18/300\n",
            "15/15 - 0s - loss: 1.1514e-04 - 172ms/epoch - 11ms/step\n",
            "Epoch 19/300\n",
            "15/15 - 0s - loss: 1.2201e-04 - 166ms/epoch - 11ms/step\n",
            "Epoch 20/300\n",
            "15/15 - 0s - loss: 9.0416e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 21/300\n",
            "15/15 - 0s - loss: 3.3621e-05 - 162ms/epoch - 11ms/step\n",
            "Epoch 22/300\n",
            "15/15 - 0s - loss: 3.5003e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 23/300\n",
            "15/15 - 0s - loss: 2.5957e-05 - 169ms/epoch - 11ms/step\n",
            "Epoch 24/300\n",
            "15/15 - 0s - loss: 1.8330e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 25/300\n",
            "15/15 - 0s - loss: 2.0533e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 26/300\n",
            "15/15 - 0s - loss: 2.0419e-05 - 164ms/epoch - 11ms/step\n",
            "Epoch 27/300\n",
            "15/15 - 0s - loss: 1.7338e-05 - 167ms/epoch - 11ms/step\n",
            "Epoch 28/300\n",
            "15/15 - 0s - loss: 1.1321e-05 - 162ms/epoch - 11ms/step\n",
            "Epoch 29/300\n",
            "15/15 - 0s - loss: 1.0546e-05 - 162ms/epoch - 11ms/step\n",
            "Epoch 30/300\n",
            "15/15 - 0s - loss: 1.0944e-05 - 169ms/epoch - 11ms/step\n",
            "Epoch 31/300\n",
            "15/15 - 0s - loss: 1.1172e-05 - 167ms/epoch - 11ms/step\n",
            "Epoch 32/300\n",
            "15/15 - 0s - loss: 6.0298e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 33/300\n",
            "15/15 - 0s - loss: 6.9388e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 34/300\n",
            "15/15 - 0s - loss: 6.2688e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 35/300\n",
            "15/15 - 0s - loss: 4.5486e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 36/300\n",
            "15/15 - 0s - loss: 4.6380e-06 - 164ms/epoch - 11ms/step\n",
            "Epoch 37/300\n",
            "15/15 - 0s - loss: 4.7684e-06 - 169ms/epoch - 11ms/step\n",
            "Epoch 38/300\n",
            "15/15 - 0s - loss: 1.5904e-05 - 161ms/epoch - 11ms/step\n",
            "Epoch 39/300\n",
            "15/15 - 0s - loss: 1.5738e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 40/300\n",
            "15/15 - 0s - loss: 2.6713e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 41/300\n",
            "15/15 - 0s - loss: 2.3722e-06 - 166ms/epoch - 11ms/step\n",
            "Epoch 42/300\n",
            "15/15 - 0s - loss: 3.4382e-06 - 170ms/epoch - 11ms/step\n",
            "Epoch 43/300\n",
            "15/15 - 0s - loss: 2.8751e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 44/300\n",
            "15/15 - 0s - loss: 3.6055e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 45/300\n",
            "15/15 - 0s - loss: 2.0925e-06 - 152ms/epoch - 10ms/step\n",
            "Epoch 46/300\n",
            "15/15 - 0s - loss: 3.7457e-06 - 151ms/epoch - 10ms/step\n",
            "Epoch 47/300\n",
            "15/15 - 0s - loss: 8.4038e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 48/300\n",
            "15/15 - 0s - loss: 1.5027e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 49/300\n",
            "15/15 - 0s - loss: 1.9858e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 50/300\n",
            "15/15 - 0s - loss: 7.2850e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 51/300\n",
            "15/15 - 0s - loss: 1.0350e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 52/300\n",
            "15/15 - 0s - loss: 9.2933e-05 - 152ms/epoch - 10ms/step\n",
            "Epoch 53/300\n",
            "15/15 - 0s - loss: 5.7264e-04 - 153ms/epoch - 10ms/step\n",
            "Epoch 54/300\n",
            "15/15 - 0s - loss: 7.0577e-04 - 156ms/epoch - 10ms/step\n",
            "Epoch 55/300\n",
            "15/15 - 0s - loss: 2.0676e-04 - 154ms/epoch - 10ms/step\n",
            "Epoch 56/300\n",
            "15/15 - 0s - loss: 6.8527e-05 - 150ms/epoch - 10ms/step\n",
            "Epoch 57/300\n",
            "15/15 - 0s - loss: 4.0665e-05 - 157ms/epoch - 10ms/step\n",
            "Epoch 58/300\n",
            "15/15 - 0s - loss: 7.2068e-06 - 151ms/epoch - 10ms/step\n",
            "Epoch 59/300\n",
            "15/15 - 0s - loss: 3.5091e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 60/300\n",
            "15/15 - 0s - loss: 3.4957e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 61/300\n",
            "15/15 - 0s - loss: 8.3725e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 62/300\n",
            "15/15 - 0s - loss: 1.6718e-05 - 158ms/epoch - 11ms/step\n",
            "Epoch 63/300\n",
            "15/15 - 0s - loss: 1.9949e-05 - 154ms/epoch - 10ms/step\n",
            "Epoch 64/300\n",
            "15/15 - 0s - loss: 9.0646e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 65/300\n",
            "15/15 - 0s - loss: 1.2163e-05 - 156ms/epoch - 10ms/step\n",
            "Epoch 66/300\n",
            "15/15 - 0s - loss: 1.5129e-05 - 152ms/epoch - 10ms/step\n",
            "Epoch 67/300\n",
            "15/15 - 0s - loss: 9.2843e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 68/300\n",
            "15/15 - 0s - loss: 1.3217e-05 - 158ms/epoch - 11ms/step\n",
            "Epoch 69/300\n",
            "15/15 - 0s - loss: 1.7184e-05 - 163ms/epoch - 11ms/step\n",
            "Epoch 70/300\n",
            "15/15 - 0s - loss: 3.9230e-06 - 150ms/epoch - 10ms/step\n",
            "Epoch 71/300\n",
            "15/15 - 0s - loss: 9.9230e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 72/300\n",
            "15/15 - 0s - loss: 1.0119e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 73/300\n",
            "15/15 - 0s - loss: 8.1462e-06 - 151ms/epoch - 10ms/step\n",
            "Epoch 74/300\n",
            "15/15 - 0s - loss: 5.1228e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 75/300\n",
            "15/15 - 0s - loss: 5.1433e-06 - 150ms/epoch - 10ms/step\n",
            "Epoch 76/300\n",
            "15/15 - 0s - loss: 2.0605e-06 - 162ms/epoch - 11ms/step\n",
            "Epoch 77/300\n",
            "15/15 - 0s - loss: 7.6334e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 78/300\n",
            "15/15 - 0s - loss: 9.4166e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 79/300\n",
            "15/15 - 0s - loss: 1.7487e-06 - 152ms/epoch - 10ms/step\n",
            "Epoch 80/300\n",
            "15/15 - 0s - loss: 1.0295e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 81/300\n",
            "15/15 - 0s - loss: 2.8929e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 82/300\n",
            "15/15 - 0s - loss: 7.4247e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 83/300\n",
            "15/15 - 0s - loss: 4.7313e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 84/300\n",
            "15/15 - 0s - loss: 1.3355e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 85/300\n",
            "15/15 - 0s - loss: 1.2078e-05 - 153ms/epoch - 10ms/step\n",
            "Epoch 86/300\n",
            "15/15 - 0s - loss: 4.9421e-05 - 152ms/epoch - 10ms/step\n",
            "Epoch 87/300\n",
            "15/15 - 0s - loss: 1.7619e-04 - 152ms/epoch - 10ms/step\n",
            "Epoch 88/300\n",
            "15/15 - 0s - loss: 2.8511e-05 - 162ms/epoch - 11ms/step\n",
            "Epoch 89/300\n",
            "15/15 - 0s - loss: 9.1337e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 90/300\n",
            "15/15 - 0s - loss: 1.2788e-05 - 150ms/epoch - 10ms/step\n",
            "Epoch 91/300\n",
            "15/15 - 0s - loss: 1.7800e-05 - 154ms/epoch - 10ms/step\n",
            "Epoch 92/300\n",
            "15/15 - 0s - loss: 4.3808e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 93/300\n",
            "15/15 - 0s - loss: 2.5492e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 94/300\n",
            "15/15 - 0s - loss: 2.7812e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 95/300\n",
            "15/15 - 0s - loss: 2.8194e-06 - 154ms/epoch - 10ms/step\n",
            "Epoch 96/300\n",
            "15/15 - 0s - loss: 1.4666e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 97/300\n",
            "15/15 - 0s - loss: 1.1745e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 98/300\n",
            "15/15 - 0s - loss: 4.2295e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 99/300\n",
            "15/15 - 0s - loss: 3.2026e-07 - 153ms/epoch - 10ms/step\n",
            "Epoch 100/300\n",
            "15/15 - 0s - loss: 2.3621e-07 - 148ms/epoch - 10ms/step\n",
            "Epoch 101/300\n",
            "15/15 - 0s - loss: 4.3640e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 102/300\n",
            "15/15 - 0s - loss: 2.1619e-07 - 156ms/epoch - 10ms/step\n",
            "Epoch 103/300\n",
            "15/15 - 0s - loss: 3.2122e-07 - 154ms/epoch - 10ms/step\n",
            "Epoch 104/300\n",
            "15/15 - 0s - loss: 3.5772e-07 - 155ms/epoch - 10ms/step\n",
            "Epoch 105/300\n",
            "15/15 - 0s - loss: 1.9384e-07 - 151ms/epoch - 10ms/step\n",
            "Epoch 106/300\n",
            "15/15 - 0s - loss: 1.1047e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 107/300\n",
            "15/15 - 0s - loss: 1.2855e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 108/300\n",
            "15/15 - 0s - loss: 1.0713e-07 - 156ms/epoch - 10ms/step\n",
            "Epoch 109/300\n",
            "15/15 - 0s - loss: 1.3824e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 110/300\n",
            "15/15 - 0s - loss: 1.7921e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 111/300\n",
            "15/15 - 0s - loss: 5.4620e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 112/300\n",
            "15/15 - 0s - loss: 3.5711e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 113/300\n",
            "15/15 - 0s - loss: 2.8771e-06 - 152ms/epoch - 10ms/step\n",
            "Epoch 114/300\n",
            "15/15 - 0s - loss: 1.8036e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 115/300\n",
            "15/15 - 0s - loss: 7.9089e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 116/300\n",
            "15/15 - 0s - loss: 7.1739e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 117/300\n",
            "15/15 - 0s - loss: 4.0318e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 118/300\n",
            "15/15 - 0s - loss: 2.0311e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 119/300\n",
            "15/15 - 0s - loss: 3.7540e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 120/300\n",
            "15/15 - 0s - loss: 4.6012e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 121/300\n",
            "15/15 - 0s - loss: 1.4149e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 122/300\n",
            "15/15 - 0s - loss: 5.4807e-07 - 156ms/epoch - 10ms/step\n",
            "Epoch 123/300\n",
            "15/15 - 0s - loss: 2.5128e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 124/300\n",
            "15/15 - 0s - loss: 1.0745e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 125/300\n",
            "15/15 - 0s - loss: 1.3167e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 126/300\n",
            "15/15 - 0s - loss: 1.8597e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 127/300\n",
            "15/15 - 0s - loss: 7.4599e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 128/300\n",
            "15/15 - 0s - loss: 1.0053e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 129/300\n",
            "15/15 - 0s - loss: 2.0620e-05 - 155ms/epoch - 10ms/step\n",
            "Epoch 130/300\n",
            "15/15 - 0s - loss: 1.8044e-04 - 157ms/epoch - 10ms/step\n",
            "Epoch 131/300\n",
            "15/15 - 0s - loss: 1.2260e-04 - 164ms/epoch - 11ms/step\n",
            "Epoch 132/300\n",
            "15/15 - 0s - loss: 4.5144e-05 - 161ms/epoch - 11ms/step\n",
            "Epoch 133/300\n",
            "15/15 - 0s - loss: 6.3156e-05 - 157ms/epoch - 10ms/step\n",
            "Epoch 134/300\n",
            "15/15 - 0s - loss: 7.4351e-05 - 156ms/epoch - 10ms/step\n",
            "Epoch 135/300\n",
            "15/15 - 0s - loss: 5.5949e-05 - 158ms/epoch - 11ms/step\n",
            "Epoch 136/300\n",
            "15/15 - 0s - loss: 1.2320e-04 - 160ms/epoch - 11ms/step\n",
            "Epoch 137/300\n",
            "15/15 - 0s - loss: 4.5785e-05 - 158ms/epoch - 11ms/step\n",
            "Epoch 138/300\n",
            "15/15 - 0s - loss: 1.1484e-05 - 155ms/epoch - 10ms/step\n",
            "Epoch 139/300\n",
            "15/15 - 0s - loss: 2.0034e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 140/300\n",
            "15/15 - 0s - loss: 5.0207e-05 - 155ms/epoch - 10ms/step\n",
            "Epoch 141/300\n",
            "15/15 - 0s - loss: 2.4747e-05 - 161ms/epoch - 11ms/step\n",
            "Epoch 142/300\n",
            "15/15 - 0s - loss: 2.6868e-06 - 162ms/epoch - 11ms/step\n",
            "Epoch 143/300\n",
            "15/15 - 0s - loss: 1.8803e-06 - 160ms/epoch - 11ms/step\n",
            "Epoch 144/300\n",
            "15/15 - 0s - loss: 1.2996e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 145/300\n",
            "15/15 - 0s - loss: 1.3690e-05 - 154ms/epoch - 10ms/step\n",
            "Epoch 146/300\n",
            "15/15 - 0s - loss: 2.4086e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 147/300\n",
            "15/15 - 0s - loss: 9.8821e-07 - 161ms/epoch - 11ms/step\n",
            "Epoch 148/300\n",
            "15/15 - 0s - loss: 7.1822e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 149/300\n",
            "15/15 - 0s - loss: 6.6779e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 150/300\n",
            "15/15 - 0s - loss: 2.8166e-06 - 154ms/epoch - 10ms/step\n",
            "Epoch 151/300\n",
            "15/15 - 0s - loss: 3.2752e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 152/300\n",
            "15/15 - 0s - loss: 4.4531e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 153/300\n",
            "15/15 - 0s - loss: 1.1895e-06 - 161ms/epoch - 11ms/step\n",
            "Epoch 154/300\n",
            "15/15 - 0s - loss: 5.8231e-06 - 160ms/epoch - 11ms/step\n",
            "Epoch 155/300\n",
            "15/15 - 0s - loss: 5.0539e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 156/300\n",
            "15/15 - 0s - loss: 5.2146e-05 - 158ms/epoch - 11ms/step\n",
            "Epoch 157/300\n",
            "15/15 - 0s - loss: 6.6994e-05 - 156ms/epoch - 10ms/step\n",
            "Epoch 158/300\n",
            "15/15 - 0s - loss: 7.5773e-05 - 156ms/epoch - 10ms/step\n",
            "Epoch 159/300\n",
            "15/15 - 0s - loss: 2.2007e-04 - 156ms/epoch - 10ms/step\n",
            "Epoch 160/300\n",
            "15/15 - 0s - loss: 2.6289e-05 - 152ms/epoch - 10ms/step\n",
            "Epoch 161/300\n",
            "15/15 - 0s - loss: 1.4910e-05 - 153ms/epoch - 10ms/step\n",
            "Epoch 162/300\n",
            "15/15 - 0s - loss: 7.8941e-05 - 160ms/epoch - 11ms/step\n",
            "Epoch 163/300\n",
            "15/15 - 0s - loss: 6.5873e-05 - 163ms/epoch - 11ms/step\n",
            "Epoch 164/300\n",
            "15/15 - 0s - loss: 5.7652e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 165/300\n",
            "15/15 - 0s - loss: 6.8477e-05 - 169ms/epoch - 11ms/step\n",
            "Epoch 166/300\n",
            "15/15 - 0s - loss: 1.4567e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 167/300\n",
            "15/15 - 0s - loss: 1.3806e-05 - 167ms/epoch - 11ms/step\n",
            "Epoch 168/300\n",
            "15/15 - 0s - loss: 5.3796e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 169/300\n",
            "15/15 - 0s - loss: 5.1897e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 170/300\n",
            "15/15 - 0s - loss: 5.9178e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 171/300\n",
            "15/15 - 0s - loss: 7.9625e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 172/300\n",
            "15/15 - 0s - loss: 9.0737e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 173/300\n",
            "15/15 - 0s - loss: 5.9115e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 174/300\n",
            "15/15 - 0s - loss: 2.0013e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 175/300\n",
            "15/15 - 0s - loss: 1.1812e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 176/300\n",
            "15/15 - 0s - loss: 6.4147e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 177/300\n",
            "15/15 - 0s - loss: 1.6583e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 178/300\n",
            "15/15 - 0s - loss: 3.5972e-06 - 171ms/epoch - 11ms/step\n",
            "Epoch 179/300\n",
            "15/15 - 0s - loss: 2.4839e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 180/300\n",
            "15/15 - 0s - loss: 1.0166e-06 - 160ms/epoch - 11ms/step\n",
            "Epoch 181/300\n",
            "15/15 - 0s - loss: 1.2895e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 182/300\n",
            "15/15 - 0s - loss: 1.5877e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 183/300\n",
            "15/15 - 0s - loss: 9.0683e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 184/300\n",
            "15/15 - 0s - loss: 3.4370e-07 - 156ms/epoch - 10ms/step\n",
            "Epoch 185/300\n",
            "15/15 - 0s - loss: 2.6290e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 186/300\n",
            "15/15 - 0s - loss: 1.9069e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 187/300\n",
            "15/15 - 0s - loss: 6.3974e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 188/300\n",
            "15/15 - 0s - loss: 1.4942e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 189/300\n",
            "15/15 - 0s - loss: 6.2188e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 190/300\n",
            "15/15 - 0s - loss: 2.3015e-05 - 161ms/epoch - 11ms/step\n",
            "Epoch 191/300\n",
            "15/15 - 0s - loss: 9.8418e-05 - 155ms/epoch - 10ms/step\n",
            "Epoch 192/300\n",
            "15/15 - 0s - loss: 1.6463e-04 - 160ms/epoch - 11ms/step\n",
            "Epoch 193/300\n",
            "15/15 - 0s - loss: 2.0316e-04 - 166ms/epoch - 11ms/step\n",
            "Epoch 194/300\n",
            "15/15 - 0s - loss: 2.0298e-04 - 164ms/epoch - 11ms/step\n",
            "Epoch 195/300\n",
            "15/15 - 0s - loss: 4.8672e-04 - 162ms/epoch - 11ms/step\n",
            "Epoch 196/300\n",
            "15/15 - 0s - loss: 0.0535 - 168ms/epoch - 11ms/step\n",
            "Epoch 197/300\n",
            "15/15 - 0s - loss: 0.0445 - 160ms/epoch - 11ms/step\n",
            "Epoch 198/300\n",
            "15/15 - 0s - loss: 0.0185 - 166ms/epoch - 11ms/step\n",
            "Epoch 199/300\n",
            "15/15 - 0s - loss: 0.0094 - 169ms/epoch - 11ms/step\n",
            "Epoch 200/300\n",
            "15/15 - 0s - loss: 0.0048 - 158ms/epoch - 11ms/step\n",
            "Epoch 201/300\n",
            "15/15 - 0s - loss: 9.8256e-04 - 165ms/epoch - 11ms/step\n",
            "Epoch 202/300\n",
            "15/15 - 0s - loss: 5.2631e-04 - 166ms/epoch - 11ms/step\n",
            "Epoch 203/300\n",
            "15/15 - 0s - loss: 8.6200e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 204/300\n",
            "15/15 - 0s - loss: 3.2787e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 205/300\n",
            "15/15 - 0s - loss: 1.8895e-05 - 165ms/epoch - 11ms/step\n",
            "Epoch 206/300\n",
            "15/15 - 0s - loss: 1.3025e-05 - 159ms/epoch - 11ms/step\n",
            "Epoch 207/300\n",
            "15/15 - 0s - loss: 1.0598e-05 - 157ms/epoch - 10ms/step\n",
            "Epoch 208/300\n",
            "15/15 - 0s - loss: 7.2356e-06 - 154ms/epoch - 10ms/step\n",
            "Epoch 209/300\n",
            "15/15 - 0s - loss: 5.1274e-06 - 153ms/epoch - 10ms/step\n",
            "Epoch 210/300\n",
            "15/15 - 0s - loss: 4.4314e-06 - 168ms/epoch - 11ms/step\n",
            "Epoch 211/300\n",
            "15/15 - 0s - loss: 3.8059e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 212/300\n",
            "15/15 - 0s - loss: 3.0771e-06 - 154ms/epoch - 10ms/step\n",
            "Epoch 213/300\n",
            "15/15 - 0s - loss: 2.6971e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 214/300\n",
            "15/15 - 0s - loss: 2.6234e-06 - 152ms/epoch - 10ms/step\n",
            "Epoch 215/300\n",
            "15/15 - 0s - loss: 3.0180e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 216/300\n",
            "15/15 - 0s - loss: 3.4726e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 217/300\n",
            "15/15 - 0s - loss: 2.3030e-06 - 164ms/epoch - 11ms/step\n",
            "Epoch 218/300\n",
            "15/15 - 0s - loss: 1.9915e-06 - 172ms/epoch - 11ms/step\n",
            "Epoch 219/300\n",
            "15/15 - 0s - loss: 1.6561e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 220/300\n",
            "15/15 - 0s - loss: 1.4225e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 221/300\n",
            "15/15 - 0s - loss: 1.2966e-06 - 160ms/epoch - 11ms/step\n",
            "Epoch 222/300\n",
            "15/15 - 0s - loss: 1.2228e-06 - 166ms/epoch - 11ms/step\n",
            "Epoch 223/300\n",
            "15/15 - 0s - loss: 1.1352e-06 - 156ms/epoch - 10ms/step\n",
            "Epoch 224/300\n",
            "15/15 - 0s - loss: 1.1424e-06 - 164ms/epoch - 11ms/step\n",
            "Epoch 225/300\n",
            "15/15 - 0s - loss: 1.1585e-06 - 166ms/epoch - 11ms/step\n",
            "Epoch 226/300\n",
            "15/15 - 0s - loss: 1.2208e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 227/300\n",
            "15/15 - 0s - loss: 1.0862e-06 - 163ms/epoch - 11ms/step\n",
            "Epoch 228/300\n",
            "15/15 - 0s - loss: 1.1594e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 229/300\n",
            "15/15 - 0s - loss: 1.5344e-06 - 159ms/epoch - 11ms/step\n",
            "Epoch 230/300\n",
            "15/15 - 0s - loss: 1.7111e-06 - 158ms/epoch - 11ms/step\n",
            "Epoch 231/300\n",
            "15/15 - 0s - loss: 1.3852e-06 - 164ms/epoch - 11ms/step\n",
            "Epoch 232/300\n",
            "15/15 - 0s - loss: 8.7603e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 233/300\n",
            "15/15 - 0s - loss: 6.9070e-07 - 173ms/epoch - 12ms/step\n",
            "Epoch 234/300\n",
            "15/15 - 0s - loss: 6.5760e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 235/300\n",
            "15/15 - 0s - loss: 6.0191e-07 - 164ms/epoch - 11ms/step\n",
            "Epoch 236/300\n",
            "15/15 - 0s - loss: 6.3585e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 237/300\n",
            "15/15 - 0s - loss: 5.3993e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 238/300\n",
            "15/15 - 0s - loss: 5.6463e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 239/300\n",
            "15/15 - 0s - loss: 6.5308e-07 - 167ms/epoch - 11ms/step\n",
            "Epoch 240/300\n",
            "15/15 - 0s - loss: 6.8948e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 241/300\n",
            "15/15 - 0s - loss: 4.7476e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 242/300\n",
            "15/15 - 0s - loss: 4.8182e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 243/300\n",
            "15/15 - 0s - loss: 4.8436e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 244/300\n",
            "15/15 - 0s - loss: 4.2307e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 245/300\n",
            "15/15 - 0s - loss: 4.2961e-07 - 170ms/epoch - 11ms/step\n",
            "Epoch 246/300\n",
            "15/15 - 0s - loss: 4.5584e-07 - 164ms/epoch - 11ms/step\n",
            "Epoch 247/300\n",
            "15/15 - 0s - loss: 3.9954e-07 - 154ms/epoch - 10ms/step\n",
            "Epoch 248/300\n",
            "15/15 - 0s - loss: 3.5126e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 249/300\n",
            "15/15 - 0s - loss: 3.8651e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 250/300\n",
            "15/15 - 0s - loss: 3.5794e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 251/300\n",
            "15/15 - 0s - loss: 3.8071e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 252/300\n",
            "15/15 - 0s - loss: 3.4994e-07 - 167ms/epoch - 11ms/step\n",
            "Epoch 253/300\n",
            "15/15 - 0s - loss: 3.3035e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 254/300\n",
            "15/15 - 0s - loss: 3.4228e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 255/300\n",
            "15/15 - 0s - loss: 2.8867e-07 - 155ms/epoch - 10ms/step\n",
            "Epoch 256/300\n",
            "15/15 - 0s - loss: 4.8354e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 257/300\n",
            "15/15 - 0s - loss: 3.3378e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 258/300\n",
            "15/15 - 0s - loss: 2.8491e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 259/300\n",
            "15/15 - 0s - loss: 2.9364e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 260/300\n",
            "15/15 - 0s - loss: 4.6714e-07 - 161ms/epoch - 11ms/step\n",
            "Epoch 261/300\n",
            "15/15 - 0s - loss: 3.3356e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 262/300\n",
            "15/15 - 0s - loss: 2.1951e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 263/300\n",
            "15/15 - 0s - loss: 2.5106e-07 - 161ms/epoch - 11ms/step\n",
            "Epoch 264/300\n",
            "15/15 - 0s - loss: 4.2915e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 265/300\n",
            "15/15 - 0s - loss: 3.9619e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 266/300\n",
            "15/15 - 0s - loss: 2.4926e-07 - 170ms/epoch - 11ms/step\n",
            "Epoch 267/300\n",
            "15/15 - 0s - loss: 2.6463e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 268/300\n",
            "15/15 - 0s - loss: 2.2773e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 269/300\n",
            "15/15 - 0s - loss: 2.3046e-07 - 165ms/epoch - 11ms/step\n",
            "Epoch 270/300\n",
            "15/15 - 0s - loss: 1.8284e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 271/300\n",
            "15/15 - 0s - loss: 1.7507e-07 - 161ms/epoch - 11ms/step\n",
            "Epoch 272/300\n",
            "15/15 - 0s - loss: 2.1668e-07 - 163ms/epoch - 11ms/step\n",
            "Epoch 273/300\n",
            "15/15 - 0s - loss: 1.6727e-07 - 154ms/epoch - 10ms/step\n",
            "Epoch 274/300\n",
            "15/15 - 0s - loss: 1.6174e-07 - 152ms/epoch - 10ms/step\n",
            "Epoch 275/300\n",
            "15/15 - 0s - loss: 1.9149e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 276/300\n",
            "15/15 - 0s - loss: 2.4064e-07 - 161ms/epoch - 11ms/step\n",
            "Epoch 277/300\n",
            "15/15 - 0s - loss: 2.9025e-07 - 166ms/epoch - 11ms/step\n",
            "Epoch 278/300\n",
            "15/15 - 0s - loss: 1.5687e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 279/300\n",
            "15/15 - 0s - loss: 1.3006e-07 - 169ms/epoch - 11ms/step\n",
            "Epoch 280/300\n",
            "15/15 - 0s - loss: 1.2471e-07 - 156ms/epoch - 10ms/step\n",
            "Epoch 281/300\n",
            "15/15 - 0s - loss: 1.1829e-07 - 157ms/epoch - 10ms/step\n",
            "Epoch 282/300\n",
            "15/15 - 0s - loss: 1.3956e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 283/300\n",
            "15/15 - 0s - loss: 1.5758e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 284/300\n",
            "15/15 - 0s - loss: 1.5391e-07 - 160ms/epoch - 11ms/step\n",
            "Epoch 285/300\n",
            "15/15 - 0s - loss: 1.5173e-07 - 152ms/epoch - 10ms/step\n",
            "Epoch 286/300\n",
            "15/15 - 0s - loss: 2.4573e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 287/300\n",
            "15/15 - 0s - loss: 2.2439e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 288/300\n",
            "15/15 - 0s - loss: 2.5811e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 289/300\n",
            "15/15 - 0s - loss: 3.3408e-07 - 162ms/epoch - 11ms/step\n",
            "Epoch 290/300\n",
            "15/15 - 0s - loss: 2.6625e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 291/300\n",
            "15/15 - 0s - loss: 1.8685e-07 - 154ms/epoch - 10ms/step\n",
            "Epoch 292/300\n",
            "15/15 - 0s - loss: 7.9143e-07 - 159ms/epoch - 11ms/step\n",
            "Epoch 293/300\n",
            "15/15 - 0s - loss: 3.9901e-06 - 157ms/epoch - 10ms/step\n",
            "Epoch 294/300\n",
            "15/15 - 0s - loss: 6.7444e-07 - 158ms/epoch - 11ms/step\n",
            "Epoch 295/300\n",
            "15/15 - 0s - loss: 9.3395e-07 - 175ms/epoch - 12ms/step\n",
            "Epoch 296/300\n",
            "15/15 - 0s - loss: 1.7670e-06 - 174ms/epoch - 12ms/step\n",
            "Epoch 297/300\n",
            "15/15 - 0s - loss: 2.1740e-06 - 155ms/epoch - 10ms/step\n",
            "Epoch 298/300\n",
            "15/15 - 0s - loss: 1.3254e-06 - 165ms/epoch - 11ms/step\n",
            "Epoch 299/300\n",
            "15/15 - 0s - loss: 1.2474e-06 - 169ms/epoch - 11ms/step\n",
            "Epoch 300/300\n",
            "15/15 - 0s - loss: 1.0169e-06 - 160ms/epoch - 11ms/step\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<keras.src.callbacks.History at 0x79197c2d89a0>"
            ]
          },
          "metadata": {},
          "execution_count": 95
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "n-aNP4n3sqG_",
        "outputId": "87ad0ea7-106a-4a82-a0e7-ba5e99783ea3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 443
        }
      },
      "source": [
        "# Plotting code, feel free to ignore.\n",
        "h = 1.0\n",
        "x_min, x_max = X[:, 0].min() - 5, X[:, 0].max() + 5\n",
        "y_min, y_max = X[:, 1].min() - 5, X[:, 1].max() + 5\n",
        "xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n",
        "                     np.arange(y_min, y_max, h))\n",
        "\n",
        "# here \"model\" is your model's prediction (classification) function\n",
        "Z = tn_model.predict(np.c_[xx.ravel(), yy.ravel()])\n",
        "\n",
        "# Put the result into a color plot\n",
        "Z = Z.reshape(xx.shape)\n",
        "plt.contourf(xx, yy, Z)\n",
        "plt.axis('off')\n",
        "\n",
        "# Plot also the training points\n",
        "plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)"
      ],
      "execution_count": 96,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "16/16 [==============================] - 0s 5ms/step\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.collections.PathCollection at 0x791974456ef0>"
            ]
          },
          "metadata": {},
          "execution_count": 96
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Record the wall-clock timestamp at the end of the run (time.time() is Unix epoch seconds)\n",
        "seconds = time.time()\n",
        "print(\"Unix time in seconds at end of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "s_ukr55OORqE",
        "outputId": "6b92866c-a3cb-48e6-e690-1dba3579edaa"
      },
      "execution_count": 97,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since end of run: 1712560134.6738951\n",
            "Mon Apr  8 07:08:54 2024\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Record the wall-clock timestamp at the beginning of the run (time.time() is Unix epoch seconds)\n",
        "seconds = time.time()\n",
        "print(\"Unix time in seconds at beginning of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "o8HTyvcHchzQ",
        "outputId": "03b6969f-a4f1-4901-9e9e-23f8f950fe51"
      },
      "execution_count": 98,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since beginning of run: 1712560134.6804094\n",
            "Mon Apr  8 07:08:54 2024\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Saliency analysis: how strongly tn_model's output responds to small changes in the input.\n",
        "# Saliency = max over the last (feature) axis of |gradient of the model output w.r.t. the input|.\n",
        "@tf.function\n",
        "def compute_saliency(input_image):\n",
        "    \"\"\"Return the max absolute input gradient of tn_model's output, reduced over the last axis.\"\"\"\n",
        "    with tf.GradientTape() as tape:\n",
        "        # input_image is passed as a tf.constant, not a Variable, so the tape must be told to track it\n",
        "        tape.watch(input_image)\n",
        "        predictions = tn_model(input_image)\n",
        "    # For non-scalar predictions, tape.gradient differentiates their sum\n",
        "    grads = tape.gradient(predictions, input_image)\n",
        "    saliency_map = tf.reduce_max(tf.abs(grads), axis=-1)\n",
        "    return saliency_map\n",
        "\n",
        "# Per-point gradient saliency is exactly the same computation; keep the name as an alias\n",
        "# rather than maintaining a second, identical copy of the function.\n",
        "compute_gradient_saliency = compute_saliency\n",
        "\n",
        "# Compute the saliency map over the decision-boundary grid\n",
        "# (x_min, x_max, y_min, y_max and h come from the plotting cell above).\n",
        "def compute_saliency_map_grid():\n",
        "    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))\n",
        "    input_image = np.c_[xx.ravel(), yy.ravel()]\n",
        "    saliency_map = compute_saliency(tf.constant(input_image, dtype=tf.float32)).numpy()\n",
        "    saliency_map = saliency_map.reshape(xx.shape)\n",
        "    return xx, yy, saliency_map\n",
        "\n",
        "# Grid saliency is computed for inspection only; it is not plotted below.\n",
        "xx, yy, saliency_map = compute_saliency_map_grid()\n",
        "\n",
        "# Compute saliency values for all data points, one point (batch of 1) at a time\n",
        "def compute_saliency_maps():\n",
        "    saliency_maps = []\n",
        "    for data_point in X:\n",
        "        saliency_map = compute_gradient_saliency(tf.constant(data_point[None, :], dtype=tf.float32)).numpy()\n",
        "        saliency_maps.append(saliency_map)\n",
        "    return saliency_maps\n",
        "\n",
        "# Return the indices of the top_k data points with the highest saliency, in ascending order\n",
        "def find_top_indices(saliency_maps, top_k):\n",
        "    order = np.argsort(np.max(saliency_maps, axis=1))\n",
        "    # Slice from an explicit start: order[-top_k:] would return every index when top_k == 0\n",
        "    top_indices = order[max(len(order) - top_k, 0):]\n",
        "    return top_indices\n",
        "\n",
        "def plot_most_diagnostic(top_indices, top_k, normalized_saliency_values):\n",
        "    \"\"\"Scatter the data, circle the top_k most salient points and annotate them with their normalized saliency.\"\"\"\n",
        "    plt.figure(figsize=(8, 6))\n",
        "    plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)\n",
        "    plt.scatter(X[top_indices, 0], X[top_indices, 1], marker='o', s=200, facecolors='none', edgecolors='r', linewidths=2)\n",
        "    for index in top_indices:\n",
        "        plt.annotate(f'{normalized_saliency_values.iloc[index][\"Saliency\"]:.4f}', (X[index, 0], X[index, 1]), xytext=(X[index, 0]+0.35, X[index, 1]+0.25), arrowprops=dict(facecolor='black', arrowstyle='->'))\n",
        "    plt.title(f'Saliency Most Diagnostic Data Points (Top {top_k})')\n",
        "    plt.xlabel('Feature 1')\n",
        "    plt.ylabel('Feature 2')\n",
        "    plt.grid(True)\n",
        "    plt.axis('equal')\n",
        "    plt.show()\n",
        "\n",
        "# Compute saliency values for all data points\n",
        "saliency_maps = compute_saliency_maps()\n",
        "\n",
        "# Find the indices of the data points with the highest saliency values\n",
        "top_k = 5  # Number of top diagnostic data points to select\n",
        "top_indices = find_top_indices(saliency_maps, top_k)\n",
        "\n",
        "# Create a DataFrame to store the saliency values (one row per data point in X)\n",
        "saliency_df = pd.DataFrame(data=saliency_maps, columns=[\"Saliency\"])\n",
        "\n",
        "# Save the saliency values to a CSV file\n",
        "saliency_df.to_csv(\"saliency_values.csv\", index=False)\n",
        "\n",
        "print(\"Saliency values saved to saliency_values.csv\")\n",
        "\n",
        "# Min-max normalize the saliency values to [0, 1]. If every value is identical the range\n",
        "# is 0; divide by 1 there so the result is 0 instead of NaN.\n",
        "saliency_min = saliency_df.min()\n",
        "saliency_range = saliency_df.max() - saliency_min\n",
        "normalized_saliency = (saliency_df - saliency_min) / saliency_range.replace(0, 1)\n",
        "\n",
        "# Saving the normalized saliency values to a new CSV file\n",
        "normalized_saliency.to_csv(\"normalized_saliency_values.csv\", index=False)\n",
        "\n",
        "# Plot the most diagnostic data points\n",
        "plot_most_diagnostic(top_indices, top_k, normalized_saliency)\n",
        "\n",
        "print(\"Normalized saliency values saved to normalized_saliency_values.csv\")\n",
        "print(\"Normalized Saliency Top-k:\")\n",
        "print(normalized_saliency.nlargest(top_k, 'Saliency'))\n",
        "print(\"Normalized Saliency Max:\", normalized_saliency.max())\n",
        "print(\"Normalized Saliency Min:\", normalized_saliency.min())\n",
        "print(\"Normalized Saliency Mean:\", normalized_saliency.mean())\n",
        "print(\"Normalized Saliency Median:\", normalized_saliency.median())\n",
        "# For continuous values, mode() lists every value tied for the highest count\n",
        "print(\"Normalized Saliency Mode:\", normalized_saliency.mode())\n",
        "sum_normalized_values = normalized_saliency.sum()\n",
        "print(\"Normalized Saliency Sum:\", sum_normalized_values)\n",
        "print(\"#\")\n",
        "print(\"#\")\n",
        "print(\"#\")\n",
        "print(\"Normalized Saliency Standard Deviation:\", normalized_saliency.std())\n",
        "print(\"Normalized Saliency Skewness:\", normalized_saliency.skew())\n",
        "print(\"Normalized Saliency Kurtosis:\", normalized_saliency.kurtosis())\n",
        "print(\"Normalized Saliency Variance:\", normalized_saliency.var())\n",
        "coefficient_variation = (normalized_saliency.std() / normalized_saliency.mean()) * 100\n",
        "print(\"Normalized Saliency Coefficient of Variation:\", coefficient_variation)\n",
        "print(\"#\")\n",
        "print(\"#\")\n",
        "print(\"#\")\n",
        "cumulative_sum = normalized_saliency.cumsum()\n",
        "print(\"Cumulative Sum of Normalized Saliency Values:\", cumulative_sum)\n",
        "# Cumulative sum divided by the total count N (not a running mean); its last entry equals the overall mean\n",
        "mean_cumulative_sum = cumulative_sum / len(normalized_saliency)\n",
        "print(\"Mean of Cumulative Sum of Normalized Saliency Values:\", mean_cumulative_sum)\n",
        "rms = np.sqrt(np.mean(normalized_saliency**2))\n",
        "print(\"Normalized Saliency Root Mean Square:\", rms)\n",
        "q1 = normalized_saliency.quantile(0.25)\n",
        "q3 = normalized_saliency.quantile(0.75)  # third quartile (Q3)\n",
        "iqr = q3 - q1\n",
        "print(\"Normalized Saliency 25th Percentile:\", q1)\n",
        "print(\"Normalized Saliency 75th Percentile:\", q3)\n",
        "print(\"Normalized Saliency Interquartile Range:\", iqr)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1987
        },
        "id": "95xed6YyDClf",
        "outputId": "695ed818-fa52-4b76-9f2f-7401b4041e17"
      },
      "execution_count": 99,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saliency values saved to saliency_values.csv\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 800x600 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Normalized saliency values saved to normalized_saliency_values.csv\n",
            "Normalized Saliency Top-k:\n",
            "     Saliency\n",
            "370  1.000000\n",
            "327  0.595631\n",
            "239  0.450081\n",
            "377  0.309829\n",
            "287  0.173656\n",
            "Normalized Saliency Max: Saliency    1.0\n",
            "dtype: float32\n",
            "Normalized Saliency Min: Saliency    0.0\n",
            "dtype: float32\n",
            "Normalized Saliency Mean: Saliency    0.009756\n",
            "dtype: float32\n",
            "Normalized Saliency Median: Saliency    0.002819\n",
            "dtype: float32\n",
            "Normalized Saliency Mode:    Saliency\n",
            "0  0.001755\n",
            "1  0.002679\n",
            "2  0.002819\n",
            "3  0.003298\n",
            "4  0.003505\n",
            "5  0.003582\n",
            "6  0.007161\n",
            "7  0.007519\n",
            "Normalized Saliency Sum: Saliency    4.683109\n",
            "dtype: float32\n",
            "#\n",
            "#\n",
            "#\n",
            "Normalized Saliency Standard Deviation: Saliency    0.059118\n",
            "dtype: float32\n",
            "Normalized Saliency Skewness: Saliency    13.085361\n",
            "dtype: float32\n",
            "Normalized Saliency Kurtosis: Saliency    191.830261\n",
            "dtype: float32\n",
            "Normalized Saliency Variance: Saliency    0.003495\n",
            "dtype: float32\n",
            "Normalized Saliency Coefficient of Variation: Saliency    605.93689\n",
            "dtype: float32\n",
            "#\n",
            "#\n",
            "#\n",
            "Cumulative Sum of Normalized Saliency Values:      Saliency\n",
            "0    0.003288\n",
            "1    0.004165\n",
            "2    0.004461\n",
            "3    0.006762\n",
            "4    0.010534\n",
            "..        ...\n",
            "475  4.669758\n",
            "476  4.676190\n",
            "477  4.677881\n",
            "478  4.680497\n",
            "479  4.683109\n",
            "\n",
            "[480 rows x 1 columns]\n",
            "Mean of Cumulative Sum of Normalized Saliency Values:      Saliency\n",
            "0    0.000007\n",
            "1    0.000009\n",
            "2    0.000009\n",
            "3    0.000014\n",
            "4    0.000022\n",
            "..        ...\n",
            "475  0.009729\n",
            "476  0.009742\n",
            "477  0.009746\n",
            "478  0.009751\n",
            "479  0.009756\n",
            "\n",
            "[480 rows x 1 columns]\n",
            "Normalized Saliency Root Mean Square: 0.059856962\n",
            "Normalized Saliency 25th Percentile: Saliency    0.001322\n",
            "Name: 0.25, dtype: float64\n",
            "Normalized Saliency 75th Percentile: Saliency    0.004718\n",
            "Name: 0.75, dtype: float64\n",
            "Normalized Saliency Interquartile Range: Saliency    0.003396\n",
            "dtype: float64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Record the wall-clock timestamp at the end of the run (time.time() is Unix epoch seconds)\n",
        "seconds = time.time()\n",
        "print(\"Unix time in seconds at end of run:\", seconds)\n",
        "local_time = time.ctime(seconds)\n",
        "print(local_time)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "wfZCzuq9KY9b",
        "outputId": "e3ccc530-a074-4ff4-dc67-4fd8259305d1"
      },
      "execution_count": 100,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Time in seconds since end of run: 1712560136.4642692\n",
            "Mon Apr  8 07:08:56 2024\n"
          ]
        }
      ]
    }
  ]
}