{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "b6KgSfpahUZ1"
},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "G75A7ey4hUZ3"
},
"outputs": [],
"source": [
"import torch\n",
"from fastai.callbacks import *\n",
"from fastai.vision import *"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "C77ffh0whUZ7"
},
"source": [
"## GPU "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "syfiMzschUZ8",
"outputId": "a555f3d9-91cb-43e1-8302-a037fbb5efe9"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Check GPU availablity\n",
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "TOznfYKmhUaB",
"outputId": "6b060a7f-72ae-43c1-c0b0-e521b770ffad"
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 4,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Check mounted GPU devices\n",
"torch.cuda.device_count()"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "BfqqvsjvhUaF",
"outputId": "5afac64a-4654-41df-ed6d-6bc968cc8ef7"
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Current device you're using\n",
"# * 0-indexed *\n",
"torch.cuda.current_device()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"code_folding": [],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"colab_type": "code",
"id": "hIcrVEJWhUaK",
"outputId": "52eb889d-49ef-433b-9528-0ac9b514dc76"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wed Jun 10 22:31:28 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla V100-PCIE... On | 00000000:00:06.0 Off | 0 |\n",
"| N/A 36C P0 41W / 250W | 4990MiB / 32480MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 1 Tesla V100-PCIE... On | 00000000:00:07.0 Off | 0 |\n",
"| N/A 37C P0 43W / 250W | 936MiB / 32480MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 2 Tesla V100-PCIE... On | 00000000:00:08.0 Off | 0 |\n",
"| N/A 34C P0 39W / 250W | 936MiB / 32480MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| 0 5744 C ...u/anaconda3/envs/pytorch_p36/bin/python 941MiB |\n",
"| 0 8881 C ...u/anaconda3/envs/pytorch_p36/bin/python 941MiB |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"# Check workloads of your GPU(s)\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2BmajeAXhUaO"
},
"outputs": [],
"source": [
"# Reset your current device (if necessary)\n",
"torch.cuda.set_device(1)"
]
},
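{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative sketch: scope work to one GPU without changing the global\n",
"# default, via the torch.cuda.device context manager\n",
"# with torch.cuda.device(1):\n",
"#     t = torch.randn(2, 2, device='cuda')  # lands on GPU 1"
]
},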
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tOVolw1UhUaS",
"outputId": "eab8d203-31ee-41b2-822c-3a0abb058fe5"
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check change's been made\n",
"torch.cuda.current_device()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "wXALsotzhUaY",
"outputId": "c2118aaa-bf6c-4d23-db4b-bfc0667ed5a2"
},
"outputs": [
{
"data": {
"text/plain": [
"'Tesla V100-PCIE-32GB'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check name of your device\n",
"torch.cuda.get_device_name()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CSmixVUShUac"
},
"source": [
"# Model Prototyping"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"heading_collapsed": true,
"id": "btBiurUPxx7t"
},
"source": [
"## Helpers "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"hidden": true,
"id": "jIljYlKohUad"
},
"outputs": [],
"source": [
"def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):\n",
" \"A sequence of modules composed of Group Norm, ReLU and Conv3d in order\"\n",
" if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None\n",
" return nn.Sequential(nn.GroupNorm(num_groups, c_in),\n",
" nn.ReLU(),\n",
" nn.Conv3d(c_in, c_out, ks, **conv_kwargs))"
]
},
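{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"########## Sanity-check ############\n",
"# A quick shape check with a hypothetical random input: stride=2 with\n",
"# ks=3 and padding=1 should halve every spatial dimension\n",
"# blk = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n",
"# blk(torch.randn(1, 32, 16, 16, 16)).shape  # expected: (1, 64, 8, 8, 8)"
]
},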
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"hidden": true,
"id": "fLbUaZT7hUag"
},
"outputs": [],
"source": [
"def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):\n",
" \"A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality\"\n",
" nf_inner = nf / 2 if bottle_neck else nf\n",
" return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),\n",
" conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),\n",
" MergeLayer())"
]
},
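{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"########## Sanity-check ############\n",
"# A quick shape check with a hypothetical random input: channels and\n",
"# spatial size are preserved, and MergeLayer adds the input back as a residual\n",
"# blk = reslike_block(32, num_groups=8)\n",
"# blk(torch.randn(1, 32, 8, 8, 8)).shape  # expected: (1, 32, 8, 8, 8)"
]
},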
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"hidden": true,
"id": "BKNJoGdsgbk2"
},
"outputs": [],
"source": [
"def upsize(c_in, c_out, ks=1, scale=2):\n",
" \"Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling\"\n",
" return nn.Sequential(nn.Conv3d(c_in, c_out, ks),\n",
" nn.Upsample(scale_factor=scale, mode='trilinear'))"
]
},
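{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"########## Sanity-check ############\n",
"# A quick shape check with a hypothetical random input: the 1x1x1 Conv\n",
"# maps 256 -> 128 channels, trilinear upsampling doubles each spatial dim\n",
"# up = upsize(256, 128)\n",
"# up(torch.randn(1, 256, 5, 6, 4)).shape  # expected: (1, 128, 10, 12, 8)"
]
},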
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"hidden": true,
"id": "6CjiBTnT8LFF"
},
"outputs": [],
"source": [
"def hook_debug(module, input, output):\n",
" \"\"\"\n",
" Print out what's been hooked usually for debugging purpose\n",
" ----------------------------------------------------------\n",
" Example:\n",
" Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
" \n",
" \"\"\"\n",
" print('Hooking ' + module.__class__.__name__)\n",
" print('output size:', output.data.size())\n",
" return output"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MY131WWbx3nN"
},
"source": [
"## Encoder Part"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "f_8ynHTavdvL"
},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
" \"Encoder part\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1) \n",
" self.res_block1 = reslike_block(32, num_groups=8)\n",
" self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n",
" self.res_block2 = reslike_block(64, num_groups=8)\n",
" self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)\n",
" self.res_block3 = reslike_block(64, num_groups=8)\n",
" self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)\n",
" self.res_block4 = reslike_block(128, num_groups=8)\n",
" self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)\n",
" self.res_block5 = reslike_block(128, num_groups=8)\n",
" self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)\n",
" self.res_block6 = reslike_block(256, num_groups=8)\n",
" self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
" self.res_block7 = reslike_block(256, num_groups=8)\n",
" self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
" self.res_block8 = reslike_block(256, num_groups=8)\n",
" self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
" self.res_block9 = reslike_block(256, num_groups=8)\n",
" \n",
" def forward(self, x):\n",
" x = self.conv1(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = self.res_block1(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = self.conv_block1(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.res_block2(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.conv_block2(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.res_block3(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.conv_block3(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.res_block4(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.conv_block4(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.res_block5(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.conv_block5(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.res_block6(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.conv_block6(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.res_block7(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.conv_block7(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.res_block8(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.conv_block8(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.res_block9(x) # Output size: (1, 256, 20, 24, 16)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "bgPmjWCq6n1F"
},
"outputs": [],
"source": [
"########## Sanity-check ############\n",
"# input = torch.randn(1, 4, 160, 192, 128)\n",
"# input = input.cuda()\n",
"# encoder = Encoder()\n",
"# encoder.cuda()\n",
"# ms = [encoder.res_block1, encoder.res_block3, encoder.res_block5]\n",
"# hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
"# output = encoder(input)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BKmlf1qY74Fx"
},
"source": [
"## Decoder Part"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "p9jCdQAeBTch"
},
"outputs": [],
"source": [
"class Decoder(nn.Module):\n",
" \"Decoder Part\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.upsize1 = upsize(256, 128)\n",
" self.reslike1 = reslike_block(128, num_groups=8)\n",
" self.upsize2 = upsize(128, 64)\n",
" self.reslike2 = reslike_block(64, num_groups=8)\n",
" self.upsize3 = upsize(64, 32)\n",
" self.reslike3 = reslike_block(32, num_groups=8)\n",
" self.conv1 = nn.Conv3d(32, 3, 1) \n",
" self.sigmoid1 = torch.nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.upsize1(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = x + hooks.stored[2] # Output size: (1, 128, 40, 48, 32)\n",
" x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.upsize2(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = x + hooks.stored[1] # Output size: (1, 64, 80, 96, 64)\n",
" x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.upsize3(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = x + hooks.stored[0] # Output size: (1, 32, 160, 192, 128)\n",
" x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = self.conv1(x) # Output size: (1, 3, 160, 192, 128)\n",
" x = self.sigmoid1(x) # Output size: (1, 3, 160, 192, 128)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "54LhlCx7hOt6"
},
"outputs": [],
"source": [
"############ Sanity-check ############\n",
"# input = torch.randn(1, 256, 20, 24, 16)\n",
"# input = input.cuda()\n",
"# decoder = Decoder()\n",
"# decoder.cuda()\n",
"# output = decoder(input)\n",
"# output.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Sq9kLEFbx8sF"
},
"source": [
"## VAE Part"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"code_folding": [],
"colab": {},
"colab_type": "code",
"id": "KEpqknq3hUaq"
},
"outputs": [],
"source": [
"class VAEEncoder(nn.Module):\n",
" \"Variational auto-encoder encoder part\"\n",
" def __init__(self, latent_dim:int=128):\n",
" super().__init__()\n",
" self.latent_dim = latent_dim\n",
" self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)\n",
" self.linear1 = nn.Linear(60, 1)\n",
" \n",
" # Assumed latent variable's probability density function parameters\n",
" self.z_mean = nn.Linear(256, latent_dim)\n",
" self.z_log_var = nn.Linear(256, latent_dim)\n",
" #TODO: It should work with or without GPU\n",
" self.epsilon = torch.randn(1, latent_dim, device='cuda')\n",
" \n",
" def forward(self, x):\n",
" x = self.conv_block(x) # Output size: (1, 16, 10, 12, 8) \n",
" x = x.view(256, -1) # Output size: (256, 60) \n",
" x = self.linear1(x) # Output size: (256, 1)\n",
" x = x.view(1, 256) # Output size: (1, 256) \n",
" z_mean = self.z_mean(x) # Output size: (1, 128)\n",
" z_var = self.z_log_var(x).exp() # Output size: (1, 128) \n",
" \n",
" return z_mean + z_var * self.epsilon # Output size: (1, 128) "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"code_folding": [
0
],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "ll26pBm9tj7-",
"outputId": "f1e9300e-8e79-4c66-8d0e-6897ce6b7f80"
},
"outputs": [],
"source": [
"############ Sanity-check ############\n",
"# input = torch.randn(1, 256, 20, 24, 16)\n",
"# input = input.cuda()\n",
"# vae_encoder = VAEEncoder(latent_dim=128)\n",
"# vae_encoder.cuda()\n",
"# output = vae_encoder(output)\n",
"# output.shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "tl4tYTaXe1qw"
},
"outputs": [],
"source": [
"class VAEDecoder(nn.Module):\n",
" \"Variational auto-encoder decoder part\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.linear1 = nn.Linear(128, 256*60)\n",
" self.relu1 = nn.ReLU()\n",
" self.upsize1 = upsize(16, 256)\n",
" self.upsize2 = upsize(256, 128)\n",
" self.reslike1 = reslike_block(128, num_groups=8)\n",
" self.upsize3 = upsize(128, 64)\n",
" self.reslike2 = reslike_block(64, num_groups=8)\n",
" self.upsize4 = upsize(64, 32)\n",
" self.reslike3 = reslike_block(32, num_groups=8)\n",
" self.conv1 = nn.Conv3d(32, 4, 1)\n",
" \n",
" def forward(self, x):\n",
" x = self.linear1(x) # Output size: (1, 256*60) \n",
" x = self.relu1(x) # Output size: (1, 256*60)\n",
" x = x.view(1, 16, 10, 12, 8) # Output size: (1, 16, 10, 12, 8)\n",
" x = self.upsize1(x) # Output size: (1, 256, 20, 24, 16)\n",
" x = self.upsize2(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n",
" x = self.upsize3(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n",
" x = self.upsize4(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n",
" x = self.conv1(x) # Output size: (1, 4, 160, 192, 128) \n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "RrusoNDpzPOk"
},
"outputs": [],
"source": [
"############ Sanity-check ############\n",
"# input = torch.randn(1, 128)\n",
"# input = input.cuda()\n",
"# vae_decoder = VAEDecoder()\n",
"# vae_decoder.cuda()\n",
"# vae_decoder(output).shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dtLzCKAOEn6c"
},
"source": [
"## AutoUNet"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"code_folding": [],
"colab": {},
"colab_type": "code",
"id": "9lhVuR2QExrp"
},
"outputs": [],
"source": [
"class AutoUNet(nn.Module):\n",
" \"3D U-Net using autoencoder regularization\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.encoder = Encoder()\n",
" self.decoder = Decoder()\n",
" self.vencoder = VAEEncoder(latent_dim=128)\n",
" self.vdecoder = VAEDecoder()\n",
"\n",
" def forward(self, input):\n",
" interm_res = self.encoder(input)\n",
" top_res = self.decoder(interm_res) # Output size: (1, 3, 160, 192, 128)\n",
" bottom_res = self.vdecoder(self.vencoder(interm_res)) # Output size: (1, 4, 160, 192, 128)\n",
" return top_res, bottom_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"scrolled": true
},
"outputs": [],
"source": [
"############ Sanity-check ############\n",
"input = torch.randn(1, 4, 160, 192, 128)\n",
"input = input.cuda()\n",
"model = AutoUNet()\n",
"model.cuda()\n",
"\n",
"ms = [model.encoder.res_block1, \n",
" model.encoder.res_block3, \n",
" model.encoder.res_block5, \n",
" model.vencoder.z_mean, \n",
" model.vencoder.z_log_var]\n",
"\n",
"hooks = hook_outputs(ms, detach=False, grad=False) #check: overwrite for each iteration?\n",
"#hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
"\n",
"output = model(input)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZSPf7atqhOuG"
},
"source": [
"## Custom Loss "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "OQ4vfaR-L9Wz",
"outputId": "cd5cb780-4027-4e12-e0de-07485713db38",
"scrolled": false
},
"outputs": [],
"source": [
"# Set the global variables\n",
"_, C, H, W, D = [input.shape[i] for i in range(len(input.shape))]\n",
"c = output[0].shape[1]\n",
"\n",
"print(\"Channels:\", C)\n",
"print(\"Height:\", H)\n",
"print(\"Width:\", W)\n",
"print(\"Depth:\", D)\n",
"print(\"The Number Of Labels:\", c)"
]
},
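{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three losses defined below are combined in the Training section as $L = L_{dice} + 0.1 \\cdot L_{KL} + 0.1 \\cdot L_{L2}$: the soft dice term drives the segmentation, while the KL and L2 terms regularize through the VAE branch."
]
},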
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"code_folding": [],
"colab": {},
"colab_type": "code",
"id": "j7cmXkIvhOuI"
},
"outputs": [],
"source": [
"class SoftDiceLoss(Module): \n",
" \"Soft dice loss based on a measure of overlap between prediction and ground truth\"\n",
" def __init__(self, epsilon=1e-6, c=c):\n",
" super().__init__()\n",
" self.epsilon = epsilon\n",
" self.c = c\n",
" \n",
" def forward(self, x:Tensor, y:Tensor):\n",
" intersection = 2 * ( (x*y).sum() )\n",
" union = (x**2).sum() + (y**2).sum() \n",
" return 1 - ( ( intersection / (union + self.epsilon) ) / self.c )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"####### Sanity-check ############\n",
"loss = "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"code_folding": [],
"colab": {},
"colab_type": "code",
"id": "kOjrJ44uhOuK"
},
"outputs": [],
"source": [
"class KLDivergence(Module): \n",
" \"KL divergence between the estimated normal distribution and a prior distribution\"\n",
" N = H * W * D #hyperparameter check\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" def forward(self, z_mean:Tensor, z_log_var:Tensor):\n",
" z_var = z_log_var.exp()\n",
" return (1/self.N) * ( (z_mean**2 + z_var**2 - z_log_var**2 - 1).sum() )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"####### Sanity-check ############\n",
"loss2 = KLDivergence()(z_mean=hooks.stored[3], z_log_var=hooks.stored[4])\n",
"print(loss2)\n",
"loss2.backward()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"code_folding": [
0
],
"colab": {},
"colab_type": "code",
"id": "HycYhLrohOuM"
},
"outputs": [],
"source": [
"class L2Loss(Module): \n",
" \"Measuring the `Euclidian distance` between prediction and ground truh using `L2 Norm`\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" def forward(self, x:Tensor, y:Tensor):\n",
" return ( (x - y)**2 ).sum() "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"####### Sanity-check ############\n",
"loss3 = L2Loss()(bottom_res=output[1], orig=input)\n",
"print(loss3)\n",
"loss3.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MsP_HOw2_6Jd"
},
"source": [
"## Optimizer"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XYaFQ6nQ_8O4"
},
"outputs": [],
"source": [
"optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "GPK9Qfc0_tGL"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"code_folding": [],
"colab": {},
"colab_type": "code",
"id": "3YkqxURk_w8K"
},
"outputs": [],
"source": [
"for epoch in range(epochs):\n",
" \n",
" model.train()\n",
" for xb,yb in train_dl:\n",
" top_res, bottom_res = model(xb)\n",
" top_y, bottom_y = train_seg, input\n",
" z_mean, z_log_var = hooks.stored[4], hooks.stored[5] \n",
" loss = SoftDiceLoss()(top_res, top_y) + \\\n",
" (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
" (0.1 * L2Loss()(bottom_res, bottom_y))\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" tot_loss, tot_acc = 0., 0.\n",
" for xb, yb in valid_dl: \n",
" top_res, bottom_res = model(xb)\n",
" top_y, bottom_y = valid_seg, input\n",
" z_mean, z_log_var = hooks.stored[4], hooks.stored[5]\n",
" loss = SoftDiceLoss()(top_res, top_y) + \\\n",
" (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
" (0.1 * L2Loss()(bottom_res, bottom_y)) \n",
" tot_loss += loss\n",
" tot_acc += dice_coeff\n",
"\n",
" nv = len(valid_dl)\n",
" return tot_loss/nv, tot_acc/nv"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"heading_collapsed": true,
"id": "GXaVq0m5hUbO"
},
"source": [
"## Memory-check"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {},
"colab_type": "code",
"hidden": true,
"id": "Xuy-W1NFhUbR",
"outputId": "f9a8cc29-2291-488f-ea8f-44b63cb8bd29"
},
"outputs": [
{
"data": {
"text/plain": [
"9884946432"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Memory ocuupied by Pytorch `Tensors`\n",
"torch.cuda.memory_allocated(device=None)"
]
},
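{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: hand cached-but-unused blocks back to the driver so nvidia-smi\n",
"# reflects them as free; tensors that are still referenced are unaffected\n",
"# torch.cuda.empty_cache()"
]
},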
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 697
},
"colab_type": "code",
"hidden": true,
"id": "-F1mbF44hUbO",
"outputId": "6b4ff5a9-766a-48d0-fc2c-bd0675e303e8",
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 1 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n",
"| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n",
"| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n",
"| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n",
"| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 9482 MB | 10012 MB | 10012 MB | 542720 KB |\n",
"| from large pool | 9478 MB | 10008 MB | 10008 MB | 542720 KB |\n",
"| from small pool | 4 MB | 4 MB | 4 MB | 0 KB |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable memory | 56300 KB | 168416 KB | 981 MB | 926 MB |\n",
"| from large pool | 55744 KB | 167872 KB | 970 MB | 915 MB |\n",
"| from small pool | 556 KB | 2034 KB | 11 MB | 11 MB |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 193 | 265 | 837 | 644 |\n",
"| from large pool | 76 | 124 | 145 | 69 |\n",
"| from small pool | 117 | 142 | 692 | 575 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 193 | 265 | 837 | 644 |\n",
"| from large pool | 76 | 124 | 145 | 69 |\n",
"| from small pool | 117 | 142 | 692 | 575 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 66 | 91 | 91 | 25 |\n",
"| from large pool | 64 | 89 | 89 | 25 |\n",
"| from small pool | 2 | 2 | 2 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 10 | 31 | 100 | 90 |\n",
"| from large pool | 8 | 29 | 42 | 34 |\n",
"| from small pool | 2 | 5 | 58 | 56 |\n",
"|===========================================================================|\n",
"\n"
]
}
],
"source": [
"# Memory status\n",
"print(torch.cuda.memory_summary(device=None, abbreviated=False))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"b6KgSfpahUZ1",
"C77ffh0whUZ7",
"btBiurUPxx7t",
"MY131WWbx3nN",
"BKmlf1qY74Fx",
"Sq9kLEFbx8sF",
"dtLzCKAOEn6c",
"ZSPf7atqhOuG",
"MsP_HOw2_6Jd",
"GPK9Qfc0_tGL",
"GXaVq0m5hUbO",
"v-ODWm3ehUbG"
],
"name": "model_prototype_1.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}