{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "b6KgSfpahUZ1"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "G75A7ey4hUZ3"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from fastai.callbacks import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C77ffh0whUZ7"
   },
   "source": [
    "## GPU "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "syfiMzschUZ8",
    "outputId": "a555f3d9-91cb-43e1-8302-a037fbb5efe9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check GPU availablity\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "TOznfYKmhUaB",
    "outputId": "6b060a7f-72ae-43c1-c0b0-e521b770ffad"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 4,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check mounted GPU devices\n",
    "torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "BfqqvsjvhUaF",
    "outputId": "5afac64a-4654-41df-ed6d-6bc968cc8ef7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 16,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Current device you're using\n",
    "# * 0-indexed *\n",
    "torch.cuda.current_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "code_folding": [],
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "colab_type": "code",
    "id": "hIcrVEJWhUaK",
    "outputId": "52eb889d-49ef-433b-9528-0ac9b514dc76"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wed Jun 10 22:31:28 2020       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-PCIE...  On   | 00000000:00:06.0 Off |                    0 |\n",
      "| N/A   36C    P0    41W / 250W |   4990MiB / 32480MiB |      0%      Default |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-PCIE...  On   | 00000000:00:07.0 Off |                    0 |\n",
      "| N/A   37C    P0    43W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla V100-PCIE...  On   | 00000000:00:08.0 Off |                    0 |\n",
      "| N/A   34C    P0    39W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                       GPU Memory |\n",
      "|  GPU       PID   Type   Process name                             Usage      |\n",
      "|=============================================================================|\n",
      "|    0      5744      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
      "|    0      8881      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "# Check workloads of your GPU(s)\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2BmajeAXhUaO"
   },
   "outputs": [],
   "source": [
    "# Reset your current device (if necessary)\n",
    "torch.cuda.set_device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tOVolw1UhUaS",
    "outputId": "eab8d203-31ee-41b2-822c-3a0abb058fe5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check change's been made\n",
    "torch.cuda.current_device()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "wXALsotzhUaY",
    "outputId": "c2118aaa-bf6c-4d23-db4b-bfc0667ed5a2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Tesla V100-PCIE-32GB'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check name of your device\n",
    "torch.cuda.get_device_name()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CSmixVUShUac"
   },
   "source": [
    "# Model Prototyping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "heading_collapsed": true,
    "id": "btBiurUPxx7t"
   },
   "source": [
    "## Helpers "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "hidden": true,
    "id": "jIljYlKohUad"
   },
   "outputs": [],
   "source": [
    "def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):\n",
    "    \"A sequence of modules composed of Group Norm, ReLU and Conv3d in order\"\n",
    "    if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None\n",
    "    return nn.Sequential(nn.GroupNorm(num_groups, c_in),\n",
    "                         nn.ReLU(),\n",
    "                         nn.Conv3d(c_in, c_out, ks, **conv_kwargs))"
   ]
  },
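  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# A minimal shape check (a sketch; tensor sizes here are illustrative, not from the dataset):\n",
    "# with stride=2 and padding=1, a conv_block should halve each spatial dimension\n",
    "# x = torch.randn(1, 32, 8, 8, 8)\n",
    "# conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)(x).shape  # expected: (1, 64, 4, 4, 4)"
   ]
  },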
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "hidden": true,
    "id": "fLbUaZT7hUag"
   },
   "outputs": [],
   "source": [
    "def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):\n",
    "    \"A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality\"\n",
    "    nf_inner = nf / 2 if bottle_neck else nf\n",
    "    return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),\n",
    "                        conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),\n",
    "                        MergeLayer())"
   ]
  },
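  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# A minimal shape check (a sketch with an illustrative size): the residual block\n",
    "# should preserve its input shape, since MergeLayer adds the input back at the end\n",
    "# x = torch.randn(1, 64, 8, 8, 8)\n",
    "# reslike_block(64, num_groups=8)(x).shape  # expected: (1, 64, 8, 8, 8)"
   ]
  },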
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "hidden": true,
    "id": "BKNJoGdsgbk2"
   },
   "outputs": [],
   "source": [
    "def upsize(c_in, c_out, ks=1, scale=2):\n",
    "    \"Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling\"\n",
    "    return nn.Sequential(nn.Conv3d(c_in, c_out, ks),\n",
    "                       nn.Upsample(scale_factor=scale, mode='trilinear'))"
   ]
  },
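  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# A minimal shape check (sizes taken from the decoder's first upsize below)\n",
    "# x = torch.randn(1, 256, 20, 24, 16)\n",
    "# upsize(256, 128)(x).shape  # expected: (1, 128, 40, 48, 32)"
   ]
  },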
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "hidden": true,
    "id": "6CjiBTnT8LFF"
   },
   "outputs": [],
   "source": [
    "def hook_debug(module, input, output):\n",
    "    \"\"\"\n",
    "    Print out what's been hooked usually for debugging purpose\n",
    "    ----------------------------------------------------------\n",
    "       Example:\n",
    "       Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
    "    \n",
    "    \"\"\"\n",
    "    print('Hooking ' + module.__class__.__name__)\n",
    "    print('output size:', output.data.size())\n",
    "    return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MY131WWbx3nN"
   },
   "source": [
    "## Encoder Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "f_8ynHTavdvL"
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"Encoder part\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1)         \n",
    "        self.res_block1 = reslike_block(32, num_groups=8)\n",
    "        self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n",
    "        self.res_block2 = reslike_block(64, num_groups=8)\n",
    "        self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)\n",
    "        self.res_block3 = reslike_block(64, num_groups=8)\n",
    "        self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)\n",
    "        self.res_block4 = reslike_block(128, num_groups=8)\n",
    "        self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)\n",
    "        self.res_block5 = reslike_block(128, num_groups=8)\n",
    "        self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)\n",
    "        self.res_block6 = reslike_block(256, num_groups=8)\n",
    "        self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
    "        self.res_block7 = reslike_block(256, num_groups=8)\n",
    "        self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
    "        self.res_block8 = reslike_block(256, num_groups=8)\n",
    "        self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
    "        self.res_block9 = reslike_block(256, num_groups=8)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)                                           # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.res_block1(x)                                      # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.conv_block1(x)                                     # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.res_block2(x)                                      # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.conv_block2(x)                                     # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.res_block3(x)                                      # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.conv_block3(x)                                     # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.res_block4(x)                                      # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.conv_block4(x)                                     # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.res_block5(x)                                      # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.conv_block5(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.res_block6(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.conv_block6(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.res_block7(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.conv_block7(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.res_block8(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.conv_block8(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.res_block9(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "bgPmjWCq6n1F"
   },
   "outputs": [],
   "source": [
    "########## Sanity-check ############\n",
    "# input = torch.randn(1, 4, 160, 192, 128)\n",
    "# input = input.cuda()\n",
    "# encoder = Encoder()\n",
    "# encoder.cuda()\n",
    "# ms = [encoder.res_block1, encoder.res_block3, encoder.res_block5]\n",
    "# hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
    "# output = encoder(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BKmlf1qY74Fx"
   },
   "source": [
    "## Decoder Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "p9jCdQAeBTch"
   },
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    \"Decoder Part\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.upsize1 = upsize(256, 128)\n",
    "        self.reslike1 = reslike_block(128, num_groups=8)\n",
    "        self.upsize2 = upsize(128, 64)\n",
    "        self.reslike2 = reslike_block(64, num_groups=8)\n",
    "        self.upsize3 = upsize(64, 32)\n",
    "        self.reslike3 = reslike_block(32, num_groups=8)\n",
    "        self.conv1 = nn.Conv3d(32, 3, 1) \n",
    "        self.sigmoid1 = torch.nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.upsize1(x)                                         # Output size: (1, 128, 40, 48, 32)\n",
    "        x = x + hooks.stored[2]                                     # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.reslike1(x)                                        # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.upsize2(x)                                         # Output size: (1, 64, 80, 96, 64)\n",
    "        x = x + hooks.stored[1]                                     # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.reslike2(x)                                        # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.upsize3(x)                                         # Output size: (1, 32, 160, 192, 128)\n",
    "        x = x + hooks.stored[0]                                     # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.reslike3(x)                                        # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.conv1(x)                                           # Output size: (1, 3, 160, 192, 128)\n",
    "        x = self.sigmoid1(x)                                        # Output size: (1, 3, 160, 192, 128)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "54LhlCx7hOt6"
   },
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# input = torch.randn(1, 256, 20, 24, 16)\n",
    "# input = input.cuda()\n",
    "# decoder = Decoder()\n",
    "# decoder.cuda()\n",
    "# output = decoder(input)\n",
    "# output.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Sq9kLEFbx8sF"
   },
   "source": [
    "## VAE Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "code_folding": [],
    "colab": {},
    "colab_type": "code",
    "id": "KEpqknq3hUaq"
   },
   "outputs": [],
   "source": [
    "class VAEEncoder(nn.Module):\n",
    "    \"Variational auto-encoder encoder part\"\n",
    "    def __init__(self, latent_dim:int=128):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim\n",
    "        self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)\n",
    "        self.linear1 = nn.Linear(60, 1)\n",
    "        \n",
    "        # Assumed latent variable's probability density function parameters\n",
    "        self.z_mean = nn.Linear(256, latent_dim)\n",
    "        self.z_log_var = nn.Linear(256, latent_dim)\n",
    "        #TODO: It should work with or without GPU\n",
    "        self.epsilon = torch.randn(1, latent_dim, device='cuda')\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv_block(x)                                   # Output size: (1, 16, 10, 12, 8)                                  \n",
    "        x = x.view(256, -1)                                      # Output size: (256, 60)                                       \n",
    "        x = self.linear1(x)                                      # Output size: (256, 1)\n",
    "        x = x.view(1, 256)                                       # Output size: (1, 256)   \n",
    "        z_mean = self.z_mean(x)                                  # Output size: (1, 128)\n",
    "        z_var = self.z_log_var(x).exp()                          # Output size: (1, 128)              \n",
    "        \n",
    "        return z_mean + z_var * self.epsilon                     # Output size: (1, 128)                              "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "ll26pBm9tj7-",
    "outputId": "f1e9300e-8e79-4c66-8d0e-6897ce6b7f80"
   },
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# input = torch.randn(1, 256, 20, 24, 16)\n",
    "# input = input.cuda()\n",
    "# vae_encoder = VAEEncoder(latent_dim=128)\n",
    "# vae_encoder.cuda()\n",
    "# output = vae_encoder(output)\n",
    "# output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "tl4tYTaXe1qw"
   },
   "outputs": [],
   "source": [
    "class VAEDecoder(nn.Module):\n",
    "    \"Variational auto-encoder decoder part\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(128, 256*60)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.upsize1 = upsize(16, 256)\n",
    "        self.upsize2 = upsize(256, 128)\n",
    "        self.reslike1 = reslike_block(128, num_groups=8)\n",
    "        self.upsize3 = upsize(128, 64)\n",
    "        self.reslike2 = reslike_block(64, num_groups=8)\n",
    "        self.upsize4 = upsize(64, 32)\n",
    "        self.reslike3 = reslike_block(32, num_groups=8)\n",
    "        self.conv1 = nn.Conv3d(32, 4, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.linear1(x)                                          # Output size: (1, 256*60)      \n",
    "        x = self.relu1(x)                                            # Output size: (1, 256*60)\n",
    "        x = x.view(1, 16, 10, 12, 8)                                 # Output size: (1, 16, 10, 12, 8)\n",
    "        x = self.upsize1(x)                                          # Output size: (1, 256, 20, 24, 16)\n",
    "        x = self.upsize2(x)                                          # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.reslike1(x)                                         # Output size: (1, 128, 40, 48, 32)\n",
    "        x = self.upsize3(x)                                          # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.reslike2(x)                                         # Output size: (1, 64, 80, 96, 64)\n",
    "        x = self.upsize4(x)                                          # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.reslike3(x)                                         # Output size: (1, 32, 160, 192, 128)\n",
    "        x = self.conv1(x)                                            # Output size: (1, 4, 160, 192, 128) \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "RrusoNDpzPOk"
   },
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "# input = torch.randn(1, 128)\n",
    "# input = input.cuda()\n",
    "# vae_decoder = VAEDecoder()\n",
    "# vae_decoder.cuda()\n",
    "# vae_decoder(output).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dtLzCKAOEn6c"
   },
   "source": [
    "## AutoUNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "code_folding": [],
    "colab": {},
    "colab_type": "code",
    "id": "9lhVuR2QExrp"
   },
   "outputs": [],
   "source": [
    "class AutoUNet(nn.Module):\n",
    "  \"3D U-Net using autoencoder regularization\"\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "    self.encoder = Encoder()\n",
    "    self.decoder = Decoder()\n",
    "    self.vencoder = VAEEncoder(latent_dim=128)\n",
    "    self.vdecoder = VAEDecoder()\n",
    "\n",
    "  def forward(self, input):\n",
    "    interm_res = self.encoder(input)\n",
    "    top_res = self.decoder(interm_res)                               # Output size: (1, 3, 160, 192, 128)\n",
    "    bottom_res = self.vdecoder(self.vencoder(interm_res))            # Output size: (1, 4, 160, 192, 128)\n",
    "    return top_res, bottom_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "############ Sanity-check ############\n",
    "input = torch.randn(1, 4, 160, 192, 128)\n",
    "input = input.cuda()\n",
    "model = AutoUNet()\n",
    "model.cuda()\n",
    "\n",
    "ms = [model.encoder.res_block1, \n",
    "      model.encoder.res_block3, \n",
    "      model.encoder.res_block5, \n",
    "      model.vencoder.z_mean, \n",
    "      model.vencoder.z_log_var]\n",
    "\n",
    "hooks = hook_outputs(ms, detach=False, grad=False) #check: overwrite for each iteration?\n",
    "#hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
    "\n",
    "output = model(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ZSPf7atqhOuG"
   },
   "source": [
    "## Custom Loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "id": "OQ4vfaR-L9Wz",
    "outputId": "cd5cb780-4027-4e12-e0de-07485713db38",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set the global variables\n",
    "_, C, H, W, D = [input.shape[i] for i in range(len(input.shape))]\n",
    "c = output[0].shape[1]\n",
    "\n",
    "print(\"Channels:\", C)\n",
    "print(\"Height:\", H)\n",
    "print(\"Width:\", W)\n",
    "print(\"Depth:\", D)\n",
    "print(\"The Number Of Labels:\", c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "code_folding": [],
    "colab": {},
    "colab_type": "code",
    "id": "j7cmXkIvhOuI"
   },
   "outputs": [],
   "source": [
    "class SoftDiceLoss(Module): \n",
    "    \"Soft dice loss based on a measure of overlap between prediction and ground truth\"\n",
    "    def __init__(self, epsilon=1e-6, c=c):\n",
    "        super().__init__()\n",
    "        self.epsilon = epsilon\n",
    "        self.c = c\n",
    "    \n",
    "    def forward(self, x:Tensor, y:Tensor):\n",
    "        intersection = 2 * ( (x*y).sum() )\n",
    "        union = (x**2).sum() + (y**2).sum() \n",
    "        return 1 - ( ( intersection / (union + self.epsilon) ) / self.c )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "####### Sanity-check ############\n",
    "loss = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "code_folding": [],
    "colab": {},
    "colab_type": "code",
    "id": "kOjrJ44uhOuK"
   },
   "outputs": [],
   "source": [
    "class KLDivergence(Module): \n",
    "    \"KL divergence between the estimated normal distribution and a prior distribution\"\n",
    "    N = H * W * D  #hyperparameter check\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "    def forward(self, z_mean:Tensor, z_log_var:Tensor):\n",
    "        z_var = z_log_var.exp()\n",
    "        return (1/self.N) * ( (z_mean**2 + z_var**2 - z_log_var**2 - 1).sum() )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "####### Sanity-check ############\n",
    "loss2 = KLDivergence()(z_mean=hooks.stored[3], z_log_var=hooks.stored[4])\n",
    "print(loss2)\n",
    "loss2.backward()"
   ]
  },
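  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####### Cross-check ############\n",
    "# A sketch comparing against torch.distributions: our loss drops the usual 1/2 factor\n",
    "# and scales by 1/N, so N * KLDivergence()(mu, log_var) should equal 2 * the reference KL\n",
    "# import torch.distributions as D\n",
    "# mu, log_var = torch.ones(1, 128), torch.zeros(1, 128)\n",
    "# kl_ref = D.kl_divergence(D.Normal(mu, (0.5 * log_var).exp()),\n",
    "#                          D.Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()\n",
    "# assert torch.isclose(KLDivergence()(mu, log_var) * KLDivergence.N, 2 * kl_ref)"
   ]
  },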
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "code_folding": [
     0
    ],
    "colab": {},
    "colab_type": "code",
    "id": "HycYhLrohOuM"
   },
   "outputs": [],
   "source": [
    "class L2Loss(Module): \n",
    "    \"Measuring the `Euclidian distance` between prediction and ground truh using `L2 Norm`\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "    def forward(self, x:Tensor, y:Tensor):\n",
    "        return  ( (x - y)**2 ).sum()       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "####### Sanity-check ############\n",
    "loss3 = L2Loss()(bottom_res=output[1], orig=input)\n",
    "print(loss3)\n",
    "loss3.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MsP_HOw2_6Jd"
   },
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XYaFQ6nQ_8O4"
   },
   "outputs": [],
   "source": [
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GPK9Qfc0_tGL"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "code_folding": [],
    "colab": {},
    "colab_type": "code",
    "id": "3YkqxURk_w8K"
   },
   "outputs": [],
   "source": [
    "for epoch in range(epochs):\n",
    "  \n",
    "  model.train()\n",
    "  for xb,yb in train_dl:\n",
    "    top_res, bottom_res = model(xb)\n",
    "    top_y, bottom_y = train_seg, input\n",
    "    z_mean, z_log_var = hooks.stored[4], hooks.stored[5] \n",
    "    loss = SoftDiceLoss()(top_res, top_y) + \\\n",
    "           (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
    "           (0.1 * L2Loss()(bottom_res, bottom_y))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  model.eval()\n",
    "  with torch.no_grad():\n",
    "    tot_loss, tot_acc = 0., 0.\n",
    "    for xb, yb in valid_dl:  \n",
    "    top_res, bottom_res = model(xb)\n",
    "    top_y, bottom_y = valid_seg, input\n",
    "    z_mean, z_log_var = hooks.stored[4], hooks.stored[5]\n",
    "    loss = SoftDiceLoss()(top_res, top_y) + \\\n",
    "           (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
    "           (0.1 * L2Loss()(bottom_res, bottom_y))    \n",
    "    tot_loss += loss\n",
    "    tot_acc += dice_coeff\n",
    "\n",
    "  nv = len(valid_dl)\n",
    "  return tot_loss/nv, tot_acc/nv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "heading_collapsed": true,
    "id": "GXaVq0m5hUbO"
   },
   "source": [
    "## Memory-check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "hidden": true,
    "id": "Xuy-W1NFhUbR",
    "outputId": "f9a8cc29-2291-488f-ea8f-44b63cb8bd29"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9884946432"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Memory ocuupied by Pytorch `Tensors`\n",
    "torch.cuda.memory_allocated(device=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 697
    },
    "colab_type": "code",
    "hidden": true,
    "id": "-F1mbF44hUbO",
    "outputId": "6b4ff5a9-766a-48d0-fc2c-bd0675e303e8",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|===========================================================================|\n",
      "|                  PyTorch CUDA memory summary, device ID 1                 |\n",
      "|---------------------------------------------------------------------------|\n",
      "|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |\n",
      "|===========================================================================|\n",
      "|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocated memory      |    9427 MB |    9855 MB |   10859 MB |    1432 MB |\n",
      "|       from large pool |    9423 MB |    9851 MB |   10847 MB |    1424 MB |\n",
      "|       from small pool |       3 MB |       3 MB |      11 MB |       8 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active memory         |    9427 MB |    9855 MB |   10859 MB |    1432 MB |\n",
      "|       from large pool |    9423 MB |    9851 MB |   10847 MB |    1424 MB |\n",
      "|       from small pool |       3 MB |       3 MB |      11 MB |       8 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved memory   |    9482 MB |   10012 MB |   10012 MB |  542720 KB |\n",
      "|       from large pool |    9478 MB |   10008 MB |   10008 MB |  542720 KB |\n",
      "|       from small pool |       4 MB |       4 MB |       4 MB |       0 KB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable memory |   56300 KB |  168416 KB |     981 MB |     926 MB |\n",
      "|       from large pool |   55744 KB |  167872 KB |     970 MB |     915 MB |\n",
      "|       from small pool |     556 KB |    2034 KB |      11 MB |      11 MB |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Allocations           |     193    |     265    |     837    |     644    |\n",
      "|       from large pool |      76    |     124    |     145    |      69    |\n",
      "|       from small pool |     117    |     142    |     692    |     575    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Active allocs         |     193    |     265    |     837    |     644    |\n",
      "|       from large pool |      76    |     124    |     145    |      69    |\n",
      "|       from small pool |     117    |     142    |     692    |     575    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| GPU reserved segments |      66    |      91    |      91    |      25    |\n",
      "|       from large pool |      64    |      89    |      89    |      25    |\n",
      "|       from small pool |       2    |       2    |       2    |       0    |\n",
      "|---------------------------------------------------------------------------|\n",
      "| Non-releasable allocs |      10    |      31    |     100    |      90    |\n",
      "|       from large pool |       8    |      29    |      42    |      34    |\n",
      "|       from small pool |       2    |       5    |      58    |      56    |\n",
      "|===========================================================================|\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Memory status\n",
    "print(torch.cuda.memory_summary(device=None, abbreviated=False))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "b6KgSfpahUZ1",
    "C77ffh0whUZ7",
    "btBiurUPxx7t",
    "MY131WWbx3nN",
    "BKmlf1qY74Fx",
    "Sq9kLEFbx8sF",
    "dtLzCKAOEn6c",
    "ZSPf7atqhOuG",
    "MsP_HOw2_6Jd",
    "GPK9Qfc0_tGL",
    "GXaVq0m5hUbO",
    "v-ODWm3ehUbG"
   ],
   "name": "model_prototype_1.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}