{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "AIforGenomics.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bBEeSQ6u2-H-",
        "colab_type": "text"
      },
      "source": [
        "# Interpretability: PyTorch pipeline, Saliency Maps and Deep Dream\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkBPghB1wkgy",
        "colab_type": "text"
      },
      "source": [
        "## Part 1: PyTorch Deep Learning Pipeline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vA48M5TADk9P",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "In this part, we’ll walk through an end-to-end pipeline for a classification task on the MNIST dataset in PyTorch.\n",
        "\n",
        "We will implement the following steps:\n",
        "\n",
        "1. Download the dataset\n",
        "2. Load the dataset\n",
        "3. Define the model\n",
        "4. Define the loss function and optimizer\n",
        "5. Define the evaluation metric\n",
        "6. Train the network on the training data\n",
        "7. Report results on the train and test data (using the evaluation metric)\n",
        "\n",
        "\n",
        "Modified from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HfEuIgy8ZPCe",
        "colab_type": "text"
      },
      "source": [
        "### Download Data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l3aQ4JEzX8C7",
        "colab_type": "code",
        "outputId": "827f72fa-1528-4bdd-e7d1-9d3d99bd2bf6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "!pip install pypng"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pypng in /usr/local/lib/python3.6/dist-packages (0.0.20)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "orUb96aWYsaN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tqdm import tqdm"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8ngVVUx6YFXw",
        "colab_type": "code",
        "outputId": "c361986d-f0f8-4843-e553-e951a4e0da78",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 809
        }
      },
      "source": [
        "!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
        "!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
        "!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
        "!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
        "# -f overwrites files left over from a previous run instead of prompting\n",
        "!gunzip -f t*.gz"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2020-03-27 19:02:28--  http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
            "Resolving yann.lecun.com (yann.lecun.com)... 104.28.6.204, 104.28.7.204, 2606:4700:3033::681c:7cc, ...\n",
            "Connecting to yann.lecun.com (yann.lecun.com)|104.28.6.204|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 9912422 (9.5M) [application/x-gzip]\n",
            "Saving to: ‘train-images-idx3-ubyte.gz’\n",
            "\n",
            "train-images-idx3-u 100%[===================>]   9.45M  16.2MB/s    in 0.6s    \n",
            "\n",
            "2020-03-27 19:02:29 (16.2 MB/s) - ‘train-images-idx3-ubyte.gz’ saved [9912422/9912422]\n",
            "\n",
            "--2020-03-27 19:02:29--  http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
            "Resolving yann.lecun.com (yann.lecun.com)... 104.28.6.204, 104.28.7.204, 2606:4700:3033::681c:7cc, ...\n",
            "Connecting to yann.lecun.com (yann.lecun.com)|104.28.6.204|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 28881 (28K) [application/x-gzip]\n",
            "Saving to: ‘train-labels-idx1-ubyte.gz’\n",
            "\n",
            "train-labels-idx1-u 100%[===================>]  28.20K  --.-KB/s    in 0.05s   \n",
            "\n",
            "2020-03-27 19:02:30 (535 KB/s) - ‘train-labels-idx1-ubyte.gz’ saved [28881/28881]\n",
            "\n",
            "--2020-03-27 19:02:30--  http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
            "Resolving yann.lecun.com (yann.lecun.com)... 104.28.6.204, 104.28.7.204, 2606:4700:3033::681c:7cc, ...\n",
            "Connecting to yann.lecun.com (yann.lecun.com)|104.28.6.204|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 1648877 (1.6M) [application/x-gzip]\n",
            "Saving to: ‘t10k-images-idx3-ubyte.gz’\n",
            "\n",
            "t10k-images-idx3-ub 100%[===================>]   1.57M  5.96MB/s    in 0.3s    \n",
            "\n",
            "2020-03-27 19:02:31 (5.96 MB/s) - ‘t10k-images-idx3-ubyte.gz’ saved [1648877/1648877]\n",
            "\n",
            "--2020-03-27 19:02:32--  http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
            "Resolving yann.lecun.com (yann.lecun.com)... 104.28.6.204, 104.28.7.204, 2606:4700:3033::681c:7cc, ...\n",
            "Connecting to yann.lecun.com (yann.lecun.com)|104.28.6.204|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4542 (4.4K) [application/x-gzip]\n",
            "Saving to: ‘t10k-labels-idx1-ubyte.gz’\n",
            "\n",
            "t10k-labels-idx1-ub 100%[===================>]   4.44K  --.-KB/s    in 0s      \n",
            "\n",
            "2020-03-27 19:02:32 (543 MB/s) - ‘t10k-labels-idx1-ubyte.gz’ saved [4542/4542]\n",
            "\n",
            "gzip: t10k-images-idx3-ubyte already exists; do you wish to overwrite (y or n)? "
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v0OL4yC9YVwp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# source: https://github.com/myleott/mnist_png/blob/master/convert_mnist_to_png.py\n",
        "import os\n",
        "import struct\n",
        "import sys\n",
        "\n",
        "from array import array\n",
        "from os import path\n",
        "\n",
        "import png\n",
        "\n",
        "# source: http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py\n",
        "def read(dataset = \"training\", path = \".\"):\n",
        "    # Compare strings with ==; `is` tests object identity and is unreliable for strings\n",
        "    if dataset == \"training\":\n",
        "        fname_img = os.path.join(path, 'train-images-idx3-ubyte')\n",
        "        fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')\n",
        "    elif dataset == \"testing\":\n",
        "        fname_img = os.path.join(path, 't10k-images-idx3-ubyte')\n",
        "        fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')\n",
        "    else:\n",
        "        raise ValueError(\"dataset must be 'testing' or 'training'\")\n",
        "\n",
        "    # Label file: 8-byte big-endian header (magic number, item count), then one byte per label\n",
        "    with open(fname_lbl, 'rb') as flbl:\n",
        "        magic_nr, size = struct.unpack(\">II\", flbl.read(8))\n",
        "        lbl = array(\"b\", flbl.read())\n",
        "\n",
        "    # Image file: 16-byte header (magic number, item count, rows, cols), then raw pixel bytes\n",
        "    with open(fname_img, 'rb') as fimg:\n",
        "        magic_nr, size, rows, cols = struct.unpack(\">IIII\", fimg.read(16))\n",
        "        img = array(\"B\", fimg.read())\n",
        "\n",
        "    return lbl, img, size, rows, cols\n",
        "\n",
        "def write_dataset(labels, data, size, rows, cols, output_dir):\n",
        "    # create one output directory per digit class\n",
        "    output_dirs = [\n",
        "        path.join(output_dir, str(i))\n",
        "        for i in range(10)\n",
        "    ]\n",
        "    for out_dir in output_dirs:\n",
        "        os.makedirs(out_dir, exist_ok=True)\n",
        "\n",
        "    # write data; wrapping labels (not enumerate) lets tqdm show the total count\n",
        "    for (i, label) in enumerate(tqdm(labels)):\n",
        "        output_filename = path.join(output_dirs[label], str(i) + \".png\")\n",
        "        with open(output_filename, \"wb\") as h:\n",
        "            w = png.Writer(cols, rows, greyscale=True)\n",
        "            data_i = [\n",
        "                data[ (i*rows*cols + j*cols) : (i*rows*cols + (j+1)*cols) ]\n",
        "                for j in range(rows)\n",
        "            ]\n",
        "            w.write(h, data_i)\n",
        "\n",
        "input_path = '/content'\n",
        "output_path = '/content/mnist'\n",
        "for dataset in [\"training\", \"testing\"]:\n",
        "    labels, data, size, rows, cols = read(dataset, input_path)\n",
        "    write_dataset(labels, data, size, rows, cols,\n",
        "                  path.join(output_path, dataset))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zz84BWQTZv1z",
        "colab_type": "text"
      },
      "source": [
        "### Load and preprocess the dataset "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sm9kkKrFgER1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Create csv files\n",
        "import glob \n",
        "import pandas as pd\n",
        "\n",
        "for split in ['training', 'testing']:\n",
        "    rows = []\n",
        "    for folder in glob.glob('/content/mnist/'+ split + '/*'):\n",
        "        label = folder.split(\"/\")[-1]\n",
        "        for image_path in glob.glob(folder+ \"/*\"):\n",
        "            rows.append([image_path,label])\n",
        "    df=pd.DataFrame(rows,columns=['Path','Label'])\n",
        "    df.to_csv(split + \".csv\", index = False)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Py8ckf0VqgOK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Visualize the dataset \n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import cv2\n",
        "\n",
        "def view_dataset(paths, labels, method='cv2'):\n",
        "    fig, axs = plt.subplots(5, 5, figsize=(10, 10))\n",
        "    flatted_axs = [item for one_ax in axs for item in one_ax]\n",
        "    for ax, path, label in zip(flatted_axs, paths[:25], labels[:25]):\n",
        "        if method == 'cv2':\n",
        "            # MNIST digits are single-channel, so read them as grayscale\n",
        "            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)\n",
        "        else:\n",
        "            # only the OpenCV reader is implemented in this notebook\n",
        "            raise ValueError(\"unsupported method: \" + str(method))\n",
        "        ax.imshow(img, cmap='gray')\n",
        "        ax.set_title(label)\n",
        "        ax.axis('off')\n",
        "    plt.show() \n",
        "\n",
        "\n",
        "df = pd.read_csv('/content/training.csv')\n",
        "paths = df['Path'][:25]\n",
        "labels = df['Label'][:25]\n",
        "\n",
        "view_dataset(paths, labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "npVz8F15ipDi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install PyTorch (https://pytorch.org/)\n",
        "!pip install torch \n",
        "import torch\n",
        "\n",
        "print(torch.__version__)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rQYdkPxSZ46b",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "from torch.utils.data import Dataset as TorchDataset, DataLoader\n",
        "from PIL import Image\n",
        "\n",
        "class Dataset(TorchDataset):\n",
        "    \"\"\"MNIST images listed in a CSV file with 'Path' and 'Label' columns.\"\"\"\n",
        "\n",
        "    def __init__(self, data_split, toy=False):\n",
        "        df = pd.read_csv(data_split)\n",
        "\n",
        "        # Optionally keep a 1% random sample for quick experiments\n",
        "        if toy:\n",
        "            df = df.sample(frac=0.01)\n",
        "\n",
        "        self.img_paths = df[\"Path\"].tolist()\n",
        "        self.labels = df[\"Label\"].tolist()\n",
        "\n",
        "        # Number of distinct classes (10 for the full MNIST splits)\n",
        "        self.n_classes = len(set(self.labels))\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        # Load the image and scale pixel values to [0, 1]\n",
        "        img = np.array(Image.open(self.img_paths[index])).astype(np.float32) / 255.\n",
        "        label = self.labels[index]\n",
        "        label_vec = torch.LongTensor([label])\n",
        "        return img, label_vec\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.img_paths)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "65LGWGMTvftr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_dataset = Dataset('training.csv')\n",
        "test_dataset = Dataset('testing.csv')\n",
        "\n",
        "# The CSV rows are grouped by digit, so both loaders shuffle to produce mixed batches\n",
        "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100,\n",
        "                                           shuffle=True)\n",
        "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100,\n",
        "                                          shuffle=True)\n",
        "\n",
        "# Number of batches per epoch in each split\n",
        "print(len(train_loader))\n",
        "print(len(test_loader))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3BqLZzkD5qkp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"Step 3: Define the model\n",
        "\n",
        "We will implement LeNet\n",
        "\"\"\"\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class Flatten(nn.Module):\n",
        "    def forward(self, input):\n",
        "        return input.view(input.size(0), -1)\n",
        "\n",
        "class LeNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(LeNet, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
        "        self.fc1   = nn.Linear(256, 120)  # 256 = 16 channels * 4 * 4 after the second pooling\n",
        "        self.fc2   = nn.Linear(120, 84)\n",
        "        self.fc3   = nn.Linear(84, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (batch, 28, 28); add a channel axis to get (batch, 1, 28, 28)\n",
        "        out = F.relu(self.conv1(x[:, None, :, :]))  # (batch, 6, 24, 24)\n",
        "        out = F.max_pool2d(out, 2)                  # (batch, 6, 12, 12)\n",
        "        out = F.relu(self.conv2(out))               # (batch, 16, 8, 8)\n",
        "        out = F.max_pool2d(out, 2)                  # (batch, 16, 4, 4)\n",
        "        out = out.view(out.size(0), -1)             # (batch, 256)\n",
        "        out = F.relu(self.fc1(out))\n",
        "        out = F.relu(self.fc2(out))\n",
        "        out = self.fc3(out)                         # (batch, 10) unnormalized class scores\n",
        "        return out\n",
        "\n",
        "model = LeNet()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dNR2Huir8iDO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"Step 4: Define loss function and optimizer\n",
        "\n",
        "We will use the cross entropy loss and Adam optimizer\n",
        "\"\"\"\n",
        "import torch.optim as optim\n",
        "\n",
        "# Define the cost function\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "# Define the optimizer, learning rate \n",
        "optimizer = optim.Adam(model.parameters(), lr=0.01)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FjYyhIyz8kzP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"Step 6: Train the network on the training data\n",
        "\"\"\"\n",
        "\n",
        "for epoch in range(2): \n",
        "    for i, (inputs, labels) in enumerate(train_loader, 0):\n",
        "        # zero the parameter gradients\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        # forward propagation\n",
        "        outputs = model(inputs)\n",
        "        \n",
        "        # calculate the loss; labels have shape (batch, 1), the loss expects (batch,)\n",
        "        loss = criterion(outputs, labels.squeeze(1))\n",
        "        \n",
        "        # backpropagation + update parameters\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # print statistics\n",
        "        cost = loss.item()\n",
        "        if i % 100 == 0:    # print every 100 iterations\n",
        "            print('Epoch:' + str(epoch) + \", Iteration: \" + str(i) \n",
        "                  + \", training cost = \" + str(cost))\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "flMO9Gx68jYJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"Step 5: Define evaluation metric\n",
        "\n",
        "We will use accuracy as an evaluation metric\n",
        "\"\"\"\n",
        "\n",
        "def calculate_accuracy(loader):\n",
        "    total = 0\n",
        "    correct = 0\n",
        "  \n",
        "    all_images = []\n",
        "    all_preds = []\n",
        "    all_labels = []\n",
        "    with torch.no_grad():\n",
        "        for data in loader:\n",
        "            images, labels = data\n",
        "            outputs = model(images)\n",
        "            # the predicted class is the index of the largest output score\n",
        "            _, predicted = torch.max(outputs, 1)\n",
        "            total += labels.size(0)\n",
        "            # labels have shape (batch, 1); drop only that axis so a batch of one still works\n",
        "            correct += (predicted == labels.squeeze(1)).sum().item()\n",
        "            \n",
        "            all_images.append(images)\n",
        "            all_preds.append(predicted.numpy())\n",
        "            all_labels.append(labels)\n",
        "\n",
        "    # accuracy in percent, plus the per-batch images, predictions and labels\n",
        "    return 100 * correct / total, all_images, all_preds, all_labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eRbRt1-18rks",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"Step 7: Report results on the train and test data (using the evaluation metric)\n",
        "\"\"\"\n",
        "        \n",
        "train_accuracy, _ , _, _ = calculate_accuracy(train_loader)\n",
        "test_accuracy, images, preds, labels = calculate_accuracy(test_loader)\n",
        "\n",
        "print('Train accuracy: %f' % train_accuracy)\n",
        "print('Test accuracy: %f' % test_accuracy)\n",
        "\n",
        "images = np.concatenate(images, axis=0)\n",
        "preds = np.concatenate(preds, axis=0)\n",
        "labels = np.squeeze(np.concatenate(labels, axis=0))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GUJzZQ7zey6O",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##### VIEW PREDICTIONS #####\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def unison_shuffled_copies(a, b, c):\n",
        "    assert len(a) == len(b) == len(c)\n",
        "    p = np.random.permutation(len(a))\n",
        "    return a[p], b[p], c[p]\n",
        "\n",
        "images, labels, preds = unison_shuffled_copies(images, labels, preds)\n",
        "\n",
        "fig, axs = plt.subplots(5, 5, figsize=(10, 10))\n",
        "flatted_axs = [item for one_ax in axs for item in one_ax]\n",
        "for ax, img, label, pred in zip(flatted_axs, images[:25], labels[:25], preds[:25]):\n",
        "    ax.imshow(np.reshape(img, (28,28)))\n",
        "    ax.set_title('l:{},p:{}'.format(label, pred))\n",
        "    ax.axis('off')\n",
        "plt.show()  "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W5eAVxqYbc2N",
        "colab_type": "text"
      },
      "source": [
        "### Saving/Loading a model\n",
        "\n",
        "We will look at how you would save a trained model and then load it again for evaluation.\n",
        "\n",
        "Reference - https://pytorch.org/tutorials/beginner/saving_loading_models.html"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1PwE0NcNbuTf",
        "colab_type": "code",
        "colab": {}
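      },
      "source": [
        "\"\"\"Save a model\n",
        "\"\"\"\n",
        "\n",
        "# Two small random tensors stand in for arbitrary variables\n",
        "v1 = torch.randn(3)\n",
        "v2 = torch.randn(5)\n",
        "\n",
        "# Save variables\n",
        "torch.save(v1, 'v1.pth')\n",
        "torch.save(v2, 'v2.pth')\n",
        "\n",
        "# Save the trained LeNet weights via its state_dict ('lenet.pth' is an arbitrary file name)\n",
        "torch.save(model.state_dict(), 'lenet.pth')\n",
        "\n",
        "# Print values at v1 and v2 to verify later\n",
        "print(v1)\n",
        "print(v2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wRo8lso2nagN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\" Load a model\n",
        "\"\"\"\n",
        "\n",
        "# Load variables v1 and v2 \n",
        "v1 = torch.load('v1.pth')\n",
        "v2 = torch.load('v2.pth')\n",
        "\n",
        "# Restore the saved weights into a fresh LeNet instance and switch it to evaluation mode\n",
        "restored_model = LeNet()\n",
        "restored_model.load_state_dict(torch.load('lenet.pth'))\n",
        "restored_model.eval()\n",
        "\n",
        "# Check the values of the variables\n",
        "print(v1)\n",
        "print(v2)"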
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cXK_MJbtcMel",
        "colab_type": "text"
      },
      "source": [
        "For more help getting started with PyTorch, check out: http://cs230.stanford.edu/blog/pytorch/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uO8S-By3yBMo",
        "colab_type": "text"
      },
      "source": [
        "## Part 2: Feature importance techniques\n",
        "\n",
        "In this part, we will implement two feature visualization techniques:\n",
        "- Saliency Maps\n",
        "- Deep Dream"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r9gbMifYb1jE",
        "colab_type": "text"
      },
      "source": [
        "### Saliency Map"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SeMDRB2SHm67",
        "colab_type": "text"
      },
      "source": [
        "You can change the class whose gradient is computed"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NM2RoMlTHltn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "idx = 5 # class whose gradient will be computed"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SzF6QEQhy-Sc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def plot_saliency(idx):\n",
        "    # Forward pass\n",
        "    data = next(iter(test_loader))\n",
        "    images, labels = data\n",
        "    images.requires_grad = True # the gradient with respect to the input will be computed\n",
        "    outputs = model(images)\n",
        "\n",
        "    # Backward pass: differentiate the summed score of class idx with respect to the input\n",
        "    loss = outputs[:, idx: idx+1].sum()\n",
        "    loss.backward()\n",
        "    heatmaps = images.grad.detach().numpy() # The heatmap is the gradient with respect to the input\n",
        "\n",
        "    # Plot each of the first 8 images next to its heatmap\n",
        "    fig, axs = plt.subplots(4, 4, figsize=(10, 10))\n",
        "    fig.suptitle(\"Saliency maps with respect to class \" + str(idx) + \", green corresponds to positive gradient and pink to negative\")\n",
        "    for i in range(4):\n",
        "        for j in range(2):\n",
        "            axs[i, 2*j].imshow(images[2*i + j].detach())\n",
        "            axs[i, 2*j].axis('off')\n",
        "            axs[i, 2*j+1].imshow(heatmaps[2*i + j], cmap=\"PiYG\", vmin=-1, vmax=1)\n",
        "            axs[i, 2*j+1].axis('off')\n",
        "    plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ujfua2GTzbvi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plot_saliency(idx)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s3jl5qoJb7EY",
        "colab_type": "text"
      },
      "source": [
        "### Deep Dream"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AW4SHVmYG7U0",
        "colab_type": "text"
      },
      "source": [
        "Play with the following parameters and see what happens!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iRvtXZSWHAEF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "idx = 3 # class whose gradient will be computed\n",
        "lr = 0.05 # Learning rate for deep dream\n",
        "n_iter = 60 # Number of iterations for Deep Dream"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mp55OBKSsQMl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Forward pass\n",
        "data = next(iter(test_loader))\n",
        "images, labels = data\n",
        "images.requires_grad = True # the gradient with respect to the input will be computed\n",
        "outputs = model(images)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dz3R-eaCHKPc",
        "colab_type": "text"
      },
      "source": [
        "The Deep Dream procedure:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AqzLn8Izb-Jv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start from a detached copy so the batch in `images` stays untouched for the before/after plot\n",
        "img = images.detach().clone().requires_grad_(True)\n",
        "\n",
        "for _ in range(n_iter):\n",
        "    model.zero_grad()\n",
        "    outputs = model(img)\n",
        "    loss = outputs[:, idx: idx+1].sum() # maximize the class idx\n",
        "    loss.backward()\n",
        "\n",
        "    # Gradient ascent: add the gradient to the image and keep pixel values non-negative\n",
        "    with torch.no_grad():\n",
        "        img = torch.clamp(img + lr * img.grad, min=0)\n",
        "\n",
        "    # The result is a fresh leaf tensor, so the previous computation graph is dropped\n",
        "    img.requires_grad_(True)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hZjL5g6cHSkZ",
        "colab_type": "text"
      },
      "source": [
        "Visualize Before/After:\n",
        "\n",
        "The dreamed image is on the right of the corresponding original image"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "emFRVlDmrV99",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fig, axs = plt.subplots(4, 4, figsize=(10, 10))\n",
        "fig.suptitle(\"Deep Dream toward class \" + str(idx))\n",
        "for i in range(4):\n",
        "    for j in range(2):\n",
        "        axs[i, 2*j].imshow(images[2*i + j].detach())\n",
        "        axs[i, 2*j].axis('off')\n",
        "        axs[i, 2*j+1].imshow(img[2*i + j].detach())\n",
        "        axs[i, 2*j+1].axis('off')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}