--- /dev/null
+++ b/ndv/model.ipynb
@@ -0,0 +1,1098 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "b6KgSfpahUZ1"
+   },
+   "source": [
+    "# Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "G75A7ey4hUZ3"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from fastai.callbacks import *\n",
+    "from fastai.vision import *"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "C77ffh0whUZ7"
+   },
+   "source": [
+    "## GPU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "syfiMzschUZ8",
+    "outputId": "a555f3d9-91cb-43e1-8302-a037fbb5efe9"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {
+      "tags": []
+     },
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Check GPU availability\n",
+    "torch.cuda.is_available()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "TOznfYKmhUaB",
+    "outputId": "6b060a7f-72ae-43c1-c0b0-e521b770ffad"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {
+      "tags": []
+     },
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Check the number of mounted GPU devices\n",
+    "torch.cuda.device_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "BfqqvsjvhUaF",
+    "outputId": "5afac64a-4654-41df-ed6d-6bc968cc8ef7"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {
+      "tags": []
+     },
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# The device you're currently using (0-indexed)\n",
+    "torch.cuda.current_device()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "code_folding": [],
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 289
+    },
+    "colab_type": "code",
+    "id": "hIcrVEJWhUaK",
+    "outputId": "52eb889d-49ef-433b-9528-0ac9b514dc76"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Wed Jun 10 22:31:28 2020       \n",
+      "+-----------------------------------------------------------------------------+\n",
+      "| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |\n",
+      "|-------------------------------+----------------------+----------------------+\n",
+      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
+      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
+      "|===============================+======================+======================|\n",
+      "|   0  Tesla V100-PCIE...  On   | 00000000:00:06.0 Off |                    0 |\n",
+      "| N/A   36C    P0    41W / 250W |   4990MiB / 32480MiB |      0%      Default |\n",
+      "+-------------------------------+----------------------+----------------------+\n",
+      "|   1  Tesla V100-PCIE...  On   | 00000000:00:07.0 Off |                    0 |\n",
+      "| N/A   37C    P0    43W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
+      "+-------------------------------+----------------------+----------------------+\n",
+      "|   2  Tesla V100-PCIE...  On   | 00000000:00:08.0 Off |                    0 |\n",
+      "| N/A   34C    P0    39W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
+      "+-------------------------------+----------------------+----------------------+\n",
+      "                                                                               \n",
+      "+-----------------------------------------------------------------------------+\n",
+      "| Processes:                                                       GPU Memory |\n",
+      "|  GPU       PID   Type   Process name                             Usage      |\n",
+      "|=============================================================================|\n",
+      "|    0      5744      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
+      "|    0      8881      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
+      "+-----------------------------------------------------------------------------+\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Check the workload on your GPU(s)\n",
+    "!nvidia-smi"
+   ]
+  },
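+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If the notebook should also run on a CPU-only machine, a minimal device-selection sketch (an addition, not part of the original workflow; the `device` name is illustrative): pick a `torch.device` once and reuse it wherever `.cuda()` is called below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged sketch: fall back to CPU when no GPU is mounted\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "device"
+   ]
+  },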
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "2BmajeAXhUaO"
+   },
+   "outputs": [],
+   "source": [
+    "# Reset your current device (if necessary)\n",
+    "torch.cuda.set_device(1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "tOVolw1UhUaS",
+    "outputId": "eab8d203-31ee-41b2-822c-3a0abb058fe5"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Check that the change has been made\n",
+    "torch.cuda.current_device()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "wXALsotzhUaY",
+    "outputId": "c2118aaa-bf6c-4d23-db4b-bfc0667ed5a2"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'Tesla V100-PCIE-32GB'"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Check the name of your device\n",
+    "torch.cuda.get_device_name()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "CSmixVUShUac"
+   },
+   "source": [
+    "# Model Prototyping"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "heading_collapsed": true,
+    "id": "btBiurUPxx7t"
+   },
+   "source": [
+    "## Helpers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "hidden": true,
+    "id": "jIljYlKohUad"
+   },
+   "outputs": [],
+   "source": [
+    "def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):\n",
+    "    \"A sequence of modules: GroupNorm, ReLU and Conv3d, in that order\"\n",
+    "    if not num_groups: num_groups = c_in // 2 if c_in % 2 == 0 else None\n",
+    "    return nn.Sequential(nn.GroupNorm(num_groups, c_in),\n",
+    "                         nn.ReLU(),\n",
+    "                         nn.Conv3d(c_in, c_out, ks, **conv_kwargs))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "hidden": true,
+    "id": "fLbUaZT7hUag"
+   },
+   "outputs": [],
+   "source": [
+    "def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):\n",
+    "    \"A ResNet-like block with GroupNorm normalization and optional bottleneck functionality\"\n",
+    "    nf_inner = nf // 2 if bottle_neck else nf\n",
+    "    return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),\n",
+    "                        conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),\n",
+    "                        MergeLayer())"
+   ]
+  },
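+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick shape check for the two helpers above, in the same commented-out sanity-check style used throughout this notebook (the toy input size is an assumed example): `conv_block` with `stride=2` halves each spatial dimension, while `reslike_block` preserves the input shape."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "########## Sanity-check ############\n",
+    "# x = torch.randn(1, 32, 8, 8, 8)\n",
+    "# conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)(x).shape  # (1, 64, 4, 4, 4)\n",
+    "# reslike_block(32, num_groups=8)(x).shape                           # (1, 32, 8, 8, 8)"
+   ]
+  },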
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "hidden": true,
+    "id": "BKNJoGdsgbk2"
+   },
+   "outputs": [],
+   "source": [
+    "def upsize(c_in, c_out, ks=1, scale=2):\n",
+    "    \"Adjust the number of feature channels with a 1x1x1 Conv and scale up the spatial dimensions with 3D trilinear upsampling\"\n",
+    "    return nn.Sequential(nn.Conv3d(c_in, c_out, ks),\n",
+    "                         nn.Upsample(scale_factor=scale, mode='trilinear'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "hidden": true,
+    "id": "6CjiBTnT8LFF"
+   },
+   "outputs": [],
+   "source": [
+    "def hook_debug(module, input, output):\n",
+    "    \"\"\"\n",
+    "    Print the module being hooked and its output size, usually for debugging\n",
+    "    -------------------------------------------------------------------------\n",
+    "    Example:\n",
+    "        Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
+    "    \"\"\"\n",
+    "    print('Hooking ' + module.__class__.__name__)\n",
+    "    print('output size:', output.data.size())\n",
+    "    return output"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "MY131WWbx3nN"
+   },
+   "source": [
+    "## Encoder Part"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "id": "f_8ynHTavdvL"
+   },
+   "outputs": [],
+   "source": [
+    "class Encoder(nn.Module):\n",
+    "    \"Encoder part\"\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1)\n",
+    "        self.res_block1 = reslike_block(32, num_groups=8)\n",
+    "        self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n",
+    "        self.res_block2 = reslike_block(64, num_groups=8)\n",
+    "        self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)\n",
+    "        self.res_block3 = reslike_block(64, num_groups=8)\n",
+    "        self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)\n",
+    "        self.res_block4 = reslike_block(128, num_groups=8)\n",
+    "        self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)\n",
+    "        self.res_block5 = reslike_block(128, num_groups=8)\n",
+    "        self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)\n",
+    "        self.res_block6 = reslike_block(256, num_groups=8)\n",
+    "        self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
+    "        self.res_block7 = reslike_block(256, num_groups=8)\n",
+    "        self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
+    "        self.res_block8 = reslike_block(256, num_groups=8)\n",
+    "        self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
+    "        self.res_block9 = reslike_block(256, num_groups=8)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.conv1(x)        # Output size: (1, 32, 160, 192, 128)\n",
+    "        x = self.res_block1(x)   # Output size: (1, 32, 160, 192, 128)\n",
+    "        x = self.conv_block1(x)  # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.res_block2(x)   # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.conv_block2(x)  # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.res_block3(x)   # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.conv_block3(x)  # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.res_block4(x)   # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.conv_block4(x)  # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.res_block5(x)   # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.conv_block5(x)  # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.res_block6(x)   # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.conv_block6(x)  # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.res_block7(x)   # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.conv_block7(x)  # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.res_block8(x)   # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.conv_block8(x)  # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.res_block9(x)   # Output size: (1, 256, 20, 24, 16)\n",
+    "        return x"
+   ]
+  },
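+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The encoder halves each spatial dimension three times (the stride-2 `conv_block`s) while growing channels from 32 to 256, so a `(1, 4, 160, 192, 128)` input ends at `(1, 256, 20, 24, 16)`. A lighter-weight check with an assumed smaller input; the spatial sizes only need to be divisible by 8:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "########## Sanity-check ############\n",
+    "# small = torch.randn(1, 4, 32, 32, 32)\n",
+    "# Encoder()(small).shape  # torch.Size([1, 256, 4, 4, 4])"
+   ]
+  },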
+ " x = self.res_block5(x) # Output size: (1, 128, 40, 48, 32)\n", + " x = self.conv_block5(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.res_block6(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.conv_block6(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.res_block7(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.conv_block7(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.res_block8(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.conv_block8(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.res_block9(x) # Output size: (1, 256, 20, 24, 16)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "code_folding": [ + 0 + ], + "colab": {}, + "colab_type": "code", + "id": "bgPmjWCq6n1F" + }, + "outputs": [], + "source": [ + "########## Sanity-check ############\n", + "# input = torch.randn(1, 4, 160, 192, 128)\n", + "# input = input.cuda()\n", + "# encoder = Encoder()\n", + "# encoder.cuda()\n", + "# ms = [encoder.res_block1, encoder.res_block3, encoder.res_block5]\n", + "# hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n", + "# output = encoder(input)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BKmlf1qY74Fx" + }, + "source": [ + "## Decoder Part" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "code_folding": [ + 0 + ], + "colab": {}, + "colab_type": "code", + "id": "p9jCdQAeBTch" + }, + "outputs": [], + "source": [ + "class Decoder(nn.Module):\n", + " \"Decoder Part\"\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.upsize1 = upsize(256, 128)\n", + " self.reslike1 = reslike_block(128, num_groups=8)\n", + " self.upsize2 = upsize(128, 64)\n", + " self.reslike2 = reslike_block(64, num_groups=8)\n", + " self.upsize3 = upsize(64, 32)\n", + " self.reslike3 = reslike_block(32, num_groups=8)\n", + " self.conv1 = nn.Conv3d(32, 3, 1) \n", + " self.sigmoid1 = torch.nn.Sigmoid()\n", + "\n", + " def forward(self, x):\n", + " x = self.upsize1(x) # Output size: (1, 128, 40, 48, 32)\n", + " x = x + hooks.stored[2] # Output size: (1, 128, 40, 48, 32)\n", + " x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n", + " x = self.upsize2(x) # Output size: (1, 64, 80, 96, 64)\n", + " x = x + hooks.stored[1] # Output size: (1, 64, 80, 96, 64)\n", + " x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n", + " x = self.upsize3(x) # Output size: (1, 32, 160, 192, 128)\n", + " x = x + hooks.stored[0] # Output size: (1, 32, 160, 192, 128)\n", + " x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n", + " x = self.conv1(x) # Output size: (1, 3, 160, 192, 128)\n", + " x = self.sigmoid1(x) # Output size: (1, 3, 160, 192, 128)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "code_folding": [ + 0 + ], + "colab": {}, + "colab_type": "code", + "id": "54LhlCx7hOt6" + }, + "outputs": [], + "source": [ + "############ Sanity-check ############\n", + "# input = torch.randn(1, 256, 20, 24, 16)\n", + "# input = input.cuda()\n", + "# decoder = Decoder()\n", + "# decoder.cuda()\n", + "# output = decoder(input)\n", + "# output.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Sq9kLEFbx8sF" + }, + "source": [ + "## VAE Part" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "code_folding": [], + "colab": {}, + "colab_type": "code", + "id": "KEpqknq3hUaq" + }, + "outputs": [], + "source": [ + 
"class VAEEncoder(nn.Module):\n", + " \"Variational auto-encoder encoder part\"\n", + " def __init__(self, latent_dim:int=128):\n", + " super().__init__()\n", + " self.latent_dim = latent_dim\n", + " self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)\n", + " self.linear1 = nn.Linear(60, 1)\n", + " \n", + " # Assumed latent variable's probability density function parameters\n", + " self.z_mean = nn.Linear(256, latent_dim)\n", + " self.z_log_var = nn.Linear(256, latent_dim)\n", + " #TODO: It should work with or without GPU\n", + " self.epsilon = torch.randn(1, latent_dim, device='cuda')\n", + " \n", + " def forward(self, x):\n", + " x = self.conv_block(x) # Output size: (1, 16, 10, 12, 8) \n", + " x = x.view(256, -1) # Output size: (256, 60) \n", + " x = self.linear1(x) # Output size: (256, 1)\n", + " x = x.view(1, 256) # Output size: (1, 256) \n", + " z_mean = self.z_mean(x) # Output size: (1, 128)\n", + " z_var = self.z_log_var(x).exp() # Output size: (1, 128) \n", + " \n", + " return z_mean + z_var * self.epsilon # Output size: (1, 128) " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "code_folding": [ + 0 + ], + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "ll26pBm9tj7-", + "outputId": "f1e9300e-8e79-4c66-8d0e-6897ce6b7f80" + }, + "outputs": [], + "source": [ + "############ Sanity-check ############\n", + "# input = torch.randn(1, 256, 20, 24, 16)\n", + "# input = input.cuda()\n", + "# vae_encoder = VAEEncoder(latent_dim=128)\n", + "# vae_encoder.cuda()\n", + "# output = vae_encoder(output)\n", + "# output.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "code_folding": [ + 0 + ], + "colab": {}, + "colab_type": "code", + "id": "tl4tYTaXe1qw" + }, + "outputs": [], + "source": [ + "class VAEDecoder(nn.Module):\n", + " \"Variational auto-encoder decoder part\"\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.linear1 = nn.Linear(128, 256*60)\n", + " self.relu1 = nn.ReLU()\n", + " self.upsize1 = upsize(16, 256)\n", + " self.upsize2 = upsize(256, 128)\n", + " self.reslike1 = reslike_block(128, num_groups=8)\n", + " self.upsize3 = upsize(128, 64)\n", + " self.reslike2 = reslike_block(64, num_groups=8)\n", + " self.upsize4 = upsize(64, 32)\n", + " self.reslike3 = reslike_block(32, num_groups=8)\n", + " self.conv1 = nn.Conv3d(32, 4, 1)\n", + " \n", + " def forward(self, x):\n", + " x = self.linear1(x) # Output size: (1, 256*60) \n", + " x = self.relu1(x) # Output size: (1, 256*60)\n", + " x = x.view(1, 16, 10, 12, 8) # Output size: (1, 16, 10, 12, 8)\n", + " x = self.upsize1(x) # Output size: (1, 256, 20, 24, 16)\n", + " x = self.upsize2(x) # Output size: (1, 128, 40, 48, 32)\n", + " x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n", + " x = self.upsize3(x) # Output size: (1, 64, 80, 96, 64)\n", + " x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n", + " x = self.upsize4(x) # Output size: (1, 32, 160, 192, 128)\n", + " x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n", + " x = self.conv1(x) # Output size: (1, 4, 160, 192, 128) \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "code_folding": [ + 0 + ], + "colab": {}, + "colab_type": "code", + "id": "RrusoNDpzPOk" + }, + "outputs": [], + "source": [ + "############ Sanity-check ############\n", + "# input = torch.randn(1, 128)\n", + "# input = input.cuda()\n", + "# vae_decoder = VAEDecoder()\n", + "# 
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "id": "tl4tYTaXe1qw"
+   },
+   "outputs": [],
+   "source": [
+    "class VAEDecoder(nn.Module):\n",
+    "    \"Variational auto-encoder: decoder part\"\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.linear1 = nn.Linear(128, 256*60)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "        self.upsize1 = upsize(16, 256)\n",
+    "        self.upsize2 = upsize(256, 128)\n",
+    "        self.reslike1 = reslike_block(128, num_groups=8)\n",
+    "        self.upsize3 = upsize(128, 64)\n",
+    "        self.reslike2 = reslike_block(64, num_groups=8)\n",
+    "        self.upsize4 = upsize(64, 32)\n",
+    "        self.reslike3 = reslike_block(32, num_groups=8)\n",
+    "        self.conv1 = nn.Conv3d(32, 4, 1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.linear1(x)           # Output size: (1, 256*60)\n",
+    "        x = self.relu1(x)             # Output size: (1, 256*60)\n",
+    "        x = x.view(1, 16, 10, 12, 8)  # Output size: (1, 16, 10, 12, 8)\n",
+    "        x = self.upsize1(x)           # Output size: (1, 256, 20, 24, 16)\n",
+    "        x = self.upsize2(x)           # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.reslike1(x)          # Output size: (1, 128, 40, 48, 32)\n",
+    "        x = self.upsize3(x)           # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.reslike2(x)          # Output size: (1, 64, 80, 96, 64)\n",
+    "        x = self.upsize4(x)           # Output size: (1, 32, 160, 192, 128)\n",
+    "        x = self.reslike3(x)          # Output size: (1, 32, 160, 192, 128)\n",
+    "        x = self.conv1(x)             # Output size: (1, 4, 160, 192, 128)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "id": "RrusoNDpzPOk"
+   },
+   "outputs": [],
+   "source": [
+    "############ Sanity-check ############\n",
+    "# input = torch.randn(1, 128)\n",
+    "# input = input.cuda()\n",
+    "# vae_decoder = VAEDecoder()\n",
+    "# vae_decoder.cuda()\n",
+    "# vae_decoder(input).shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "dtLzCKAOEn6c"
+   },
+   "source": [
+    "## AutoUNet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "code_folding": [],
+    "colab": {},
+    "colab_type": "code",
+    "id": "9lhVuR2QExrp"
+   },
+   "outputs": [],
+   "source": [
+    "class AutoUNet(nn.Module):\n",
+    "    \"3D U-Net using autoencoder regularization\"\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.encoder = Encoder()\n",
+    "        self.decoder = Decoder()\n",
+    "        self.vencoder = VAEEncoder(latent_dim=128)\n",
+    "        self.vdecoder = VAEDecoder()\n",
+    "\n",
+    "    def forward(self, input):\n",
+    "        interm_res = self.encoder(input)\n",
+    "        top_res = self.decoder(interm_res)                     # Output size: (1, 3, 160, 192, 128)\n",
+    "        bottom_res = self.vdecoder(self.vencoder(interm_res))  # Output size: (1, 4, 160, 192, 128)\n",
+    "        return top_res, bottom_res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "code_folding": [],
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "############ Sanity-check ############\n",
+    "input = torch.randn(1, 4, 160, 192, 128)\n",
+    "input = input.cuda()\n",
+    "model = AutoUNet()\n",
+    "model.cuda()\n",
+    "\n",
+    "ms = [model.encoder.res_block1,\n",
+    "      model.encoder.res_block3,\n",
+    "      model.encoder.res_block5,\n",
+    "      model.vencoder.z_mean,\n",
+    "      model.vencoder.z_log_var]\n",
+    "\n",
+    "hooks = hook_outputs(ms, detach=False, grad=False)  # check: are the stored values overwritten on each forward pass?\n",
+    "#hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
+    "\n",
+    "output = model(input)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "ZSPf7atqhOuG"
+   },
+   "source": [
+    "## Custom Loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "code_folding": [],
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 85
+    },
+    "colab_type": "code",
+    "id": "OQ4vfaR-L9Wz",
+    "outputId": "cd5cb780-4027-4e12-e0de-07485713db38",
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# Set the global variables\n",
+    "_, C, H, W, D = input.shape\n",
+    "c = output[0].shape[1]\n",
+    "\n",
+    "print(\"Channels:\", C)\n",
+    "print(\"Height:\", H)\n",
+    "print(\"Width:\", W)\n",
+    "print(\"Depth:\", D)\n",
+    "print(\"Number of labels:\", c)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "code_folding": [],
+    "colab": {},
+    "colab_type": "code",
+    "id": "j7cmXkIvhOuI"
+   },
+   "outputs": [],
+   "source": [
+    "class SoftDiceLoss(Module):\n",
+    "    \"Soft Dice loss, based on a measure of overlap between prediction and ground truth\"\n",
+    "    def __init__(self, epsilon=1e-6, c=c):\n",
+    "        super().__init__()\n",
+    "        self.epsilon = epsilon\n",
+    "        self.c = c\n",
+    "\n",
+    "    def forward(self, x:Tensor, y:Tensor):\n",
+    "        intersection = 2 * ((x*y).sum())\n",
+    "        union = (x**2).sum() + (y**2).sum()\n",
+    "        return 1 - ((intersection / (union + self.epsilon)) / self.c)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "code_folding": [
+     0
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "####### Sanity-check ############\n",
+    "# The target below is a dummy tensor standing in for a real segmentation mask\n",
+    "loss = SoftDiceLoss()(x=output[0], y=torch.rand_like(output[0]))\n",
+    "print(loss)\n",
+    "loss.backward(retain_graph=True)  # retain the graph: the KL and L2 checks below backprop through it too"
+   ]
+  },
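+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For intuition about `SoftDiceLoss`: with a perfect prediction the overlap term `intersection / union` approaches 1, so the loss tends to `1 - 1/c`; for fully disjoint prediction and target it tends to 1. A commented-out numeric check with assumed toy tensors:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "####### Sanity-check ############\n",
+    "# p = torch.ones(2, 2)\n",
+    "# SoftDiceLoss(c=3)(p, p)                   # ~0.6667, i.e. 1 - 1/3\n",
+    "# SoftDiceLoss(c=3)(p, torch.zeros(2, 2))   # ~1.0"
+   ]
+  },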
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "code_folding": [],
+    "colab": {},
+    "colab_type": "code",
+    "id": "kOjrJ44uhOuK"
+   },
+   "outputs": [],
+   "source": [
+    "class KLDivergence(Module):\n",
+    "    \"KL divergence between the estimated normal distribution and a standard normal prior\"\n",
+    "    N = H * W * D  # total number of voxels in the input image\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def forward(self, z_mean:Tensor, z_log_var:Tensor):\n",
+    "        z_var = z_log_var.exp()\n",
+    "        return (1/self.N) * ((z_mean**2 + z_var - z_log_var - 1).sum())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "code_folding": []
+   },
+   "outputs": [],
+   "source": [
+    "####### Sanity-check ############\n",
+    "loss2 = KLDivergence()(z_mean=hooks.stored[3], z_log_var=hooks.stored[4])\n",
+    "print(loss2)\n",
+    "loss2.backward(retain_graph=True)  # the L2 check below still needs the graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "code_folding": [
+     0
+    ],
+    "colab": {},
+    "colab_type": "code",
+    "id": "HycYhLrohOuM"
+   },
+   "outputs": [],
+   "source": [
+    "class L2Loss(Module):\n",
+    "    \"Measures the squared `Euclidean distance` (`L2 norm` squared) between prediction and ground truth\"\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def forward(self, x:Tensor, y:Tensor):\n",
+    "        return ((x - y)**2).sum()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "code_folding": [
+     0
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "####### Sanity-check ############\n",
+    "loss3 = L2Loss()(x=output[1], y=input)\n",
+    "print(loss3)\n",
+    "loss3.backward()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "MsP_HOw2_6Jd"
+   },
+   "source": [
+    "## Optimizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "XYaFQ6nQ_8O4"
+   },
+   "outputs": [],
+   "source": [
+    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "GPK9Qfc0_tGL"
+   },
+   "source": [
+    "## Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "code_folding": [],
+    "colab": {},
+    "colab_type": "code",
+    "id": "3YkqxURk_w8K"
+   },
+   "outputs": [],
+   "source": [
+    "def fit(epochs, model, train_dl, valid_dl):\n",
+    "    \"Training loop; train_dl/valid_dl are assumed to yield (image, segmentation) batches\"\n",
+    "    for epoch in range(epochs):\n",
+    "\n",
+    "        model.train()\n",
+    "        for xb, yb in train_dl:\n",
+    "            top_res, bottom_res = model(xb)\n",
+    "            top_y, bottom_y = yb, xb  # segmentation target and reconstruction target\n",
+    "            z_mean, z_log_var = hooks.stored[3], hooks.stored[4]\n",
+    "            loss = SoftDiceLoss()(top_res, top_y) + \\\n",
+    "                   (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
+    "                   (0.1 * L2Loss()(bottom_res, bottom_y))\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "            optimizer.zero_grad()\n",
+    "\n",
+    "        model.eval()\n",
+    "        with torch.no_grad():\n",
+    "            tot_loss, tot_acc = 0., 0.\n",
+    "            for xb, yb in valid_dl:\n",
+    "                top_res, bottom_res = model(xb)\n",
+    "                top_y, bottom_y = yb, xb\n",
+    "                z_mean, z_log_var = hooks.stored[3], hooks.stored[4]\n",
+    "                loss = SoftDiceLoss()(top_res, top_y) + \\\n",
+    "                       (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
+    "                       (0.1 * L2Loss()(bottom_res, bottom_y))\n",
+    "                tot_loss += loss\n",
+    "                tot_acc += dice_coeff(top_res, top_y)  # dice_coeff: an assumed Dice metric, defined elsewhere\n",
+    "\n",
+    "    nv = len(valid_dl)\n",
+    "    return tot_loss/nv, tot_acc/nv"
+   ]
+  },
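+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before wiring up real dataloaders, the combined objective (Dice + 0.1 * KL + 0.1 * L2, as in the training loop above) can be smoke-tested on the tensors from the AutoUNet sanity check; the random `top_y` target here is an assumed stand-in for a real segmentation mask:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "####### Sanity-check ############\n",
+    "# top_y = torch.rand(1, 3, 160, 192, 128, device='cuda')\n",
+    "# loss = SoftDiceLoss()(output[0], top_y) \\\n",
+    "#        + 0.1 * KLDivergence()(hooks.stored[3], hooks.stored[4]) \\\n",
+    "#        + 0.1 * L2Loss()(output[1], input)\n",
+    "# loss.backward(); optimizer.step(); optimizer.zero_grad()"
+   ]
+  },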
"f9a8cc29-2291-488f-ea8f-44b63cb8bd29" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "9884946432" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Memory ocuupied by Pytorch `Tensors`\n", + "torch.cuda.memory_allocated(device=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 697 + }, + "colab_type": "code", + "hidden": true, + "id": "-F1mbF44hUbO", + "outputId": "6b4ff5a9-766a-48d0-fc2c-bd0675e303e8", + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "|===========================================================================|\n", + "| PyTorch CUDA memory summary, device ID 1 |\n", + "|---------------------------------------------------------------------------|\n", + "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n", + "|===========================================================================|\n", + "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n", + "|---------------------------------------------------------------------------|\n", + "| Allocated memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n", + "| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n", + "| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n", + "|---------------------------------------------------------------------------|\n", + "| Active memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n", + "| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n", + "| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n", + "|---------------------------------------------------------------------------|\n", + "| GPU reserved memory | 9482 MB | 10012 MB | 10012 MB | 542720 KB |\n", + "| from large pool | 9478 MB | 10008 MB | 10008 MB | 542720 KB |\n", + "| from small pool | 4 MB | 4 MB | 4 MB | 0 KB |\n", + "|---------------------------------------------------------------------------|\n", + "| Non-releasable memory | 56300 KB | 168416 KB | 981 MB | 926 MB |\n", + "| from large pool | 55744 KB | 167872 KB | 970 MB | 915 MB |\n", + "| from small pool | 556 KB | 2034 KB | 11 MB | 11 MB |\n", + "|---------------------------------------------------------------------------|\n", + "| Allocations | 193 | 265 | 837 | 644 |\n", + "| from large pool | 76 | 124 | 145 | 69 |\n", + "| from small pool | 117 | 142 | 692 | 575 |\n", + "|---------------------------------------------------------------------------|\n", + "| Active allocs | 193 | 265 | 837 | 644 |\n", + "| from large pool | 76 | 124 | 145 | 69 |\n", + "| from small pool | 117 | 142 | 692 | 575 |\n", + "|---------------------------------------------------------------------------|\n", + "| GPU reserved segments | 66 | 91 | 91 | 25 |\n", + "| from large pool | 64 | 89 | 89 | 25 |\n", + "| from small pool | 2 | 2 | 2 | 0 |\n", + "|---------------------------------------------------------------------------|\n", + "| Non-releasable allocs | 10 | 31 | 100 | 90 |\n", + "| from large pool | 8 | 29 | 42 | 34 |\n", + "| from small pool | 2 | 5 | 58 | 56 |\n", + "|===========================================================================|\n", + "\n" + ] + } + ], + "source": [ + "# Memory status\n", + "print(torch.cuda.memory_summary(device=None, abbreviated=False))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "b6KgSfpahUZ1", + "C77ffh0whUZ7", + "btBiurUPxx7t", + "MY131WWbx3nN", + "BKmlf1qY74Fx", + 
"Sq9kLEFbx8sF", + "dtLzCKAOEn6c", + "ZSPf7atqhOuG", + "MsP_HOw2_6Jd", + "GPK9Qfc0_tGL", + "GXaVq0m5hUbO", + "v-ODWm3ehUbG" + ], + "name": "model_prototype_1.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}