{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hands-on with the U-Net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this exercise you will take a look at a U-net architecture and will train your model with a tiny dataset. \n",
    "\n",
    "The objective of this exercise is not to train the best model ever, but rather to give you a feel for how the U-net model operates and how it can be applied to real-world datasets.\n",
    "\n",
    "For this exercise the data consists is one abdominal CT scan with segmentation of the spleen. There is not quite enough to train a reasonable model, however you can play with parameters and overfit your model to this particular case so that it produces a segmentation that looks good. That is what we are going to do!\n",
    "\n",
    "If you are feeling adventurous, and want to try and train this model for real, you can download the full spleen segmentation dataset from here: http://medicaldecathlon.com/\n",
    "\n",
    "Below we provide some starter code and put comments prepended with TASK to indicate places where you need to fill in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import numpy.ma as ma\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import nibabel as nib\n",
    "from collections import OrderedDict\n",
    "import torch.optim as optim\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructing U-net\n",
    "\n",
    "Below is a UNet implementation that is based on Mateusz Buda's kernel for Brain MRI Segmentation Challenge: https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation \n",
    "\n",
    "It closely resembles the model architecture that has been presented in the original U-net paper: https://arxiv.org/pdf/1505.04597.pdf\n",
    "\n",
    "<img src=\"hands-on.img/unet.png\" width=\"600\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet(nn.Module):\n",
    "\n",
    "    def __init__(self, in_channels=1, out_channels=1, init_features=32):\n",
    "        super(UNet, self).__init__()\n",
    "\n",
    "        # This parameter controls how far the UNet blocks grow as you go down \n",
    "        # the contracting path\n",
    "        features = init_features\n",
    "\n",
    "        # Below, we set up our layers\n",
    "        self.encoder1 = self.unet_block(in_channels, features, name=\"enc1\")\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.encoder2 = self.unet_block(features, features * 2, name=\"enc2\")\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.encoder3 = self.unet_block(features * 2, features * 4, name=\"enc3\")\n",
    "        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.encoder4 = self.unet_block(features * 4, features * 8, name=\"enc4\")\n",
    "        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "\n",
    "        self.bottleneck = self.unet_block(features * 8, features * 16, name=\"bottleneck\")\n",
    "\n",
    "        # Note the transposed convolutions here. These are the operations that perform\n",
    "        # the upsampling. This is a blog post that explains them nicely:\n",
    "        # https://medium.com/activating-robotic-minds/up-sampling-with-transposed-convolution-9ae4f2df52d0\n",
    "        self.upconv4 = nn.ConvTranspose2d(\n",
    "            features * 16, features * 8, kernel_size=2, stride=2\n",
    "        )\n",
    "        self.decoder4 = self.unet_block((features * 8) * 2, features * 8, name=\"dec4\")\n",
    "        self.upconv3 = nn.ConvTranspose2d(\n",
    "            features * 8, features * 4, kernel_size=2, stride=2\n",
    "        )\n",
    "        self.decoder3 = self.unet_block((features * 4) * 2, features * 4, name=\"dec3\")\n",
    "        self.upconv2 = nn.ConvTranspose2d(\n",
    "            features * 4, features * 2, kernel_size=2, stride=2\n",
    "        )\n",
    "        self.decoder2 = self.unet_block((features * 2) * 2, features * 2, name=\"dec2\")\n",
    "        self.upconv1 = nn.ConvTranspose2d(\n",
    "            features * 2, features, kernel_size=2, stride=2\n",
    "        )\n",
    "        self.decoder1 = self.unet_block(features * 2, features, name=\"dec1\")\n",
    "\n",
    "        self.conv = nn.Conv2d(\n",
    "            in_channels=features, out_channels=out_channels, kernel_size=1\n",
    "        )\n",
    "        \n",
    "        self.softmax = nn.Softmax(dim = 1)\n",
    "\n",
    "    # This method runs the model on a data vector. Note that this particular\n",
    "    # implementation is performing 2D convolutions, therefore it is \n",
    "    # set up to deal with 2D/2.5D approaches. If you want to try out the 3D convolutions\n",
    "    # from the previous exercise, you will need to modify the initialization code\n",
    "    def forward(self, x):\n",
    "        # Contracting/downsampling path. Each encoder here is a set of 2x convolutional layers\n",
    "        # with batch normalization, followed by activation function and max pooling\n",
    "        enc1 = self.encoder1(x)\n",
    "        enc2 = self.encoder2(self.pool1(enc1))\n",
    "        enc3 = self.encoder3(self.pool2(enc2))\n",
    "        enc4 = self.encoder4(self.pool3(enc3))\n",
    "\n",
    "        # This is the bottom-most 1-1 layer.\n",
    "        # In the original paper, a dropout layer is suggested here, but\n",
    "        # we won't use it here since our dataset is tiny and we basically want \n",
    "        # to overfit to it\n",
    "        bottleneck = self.bottleneck(self.pool4(enc4))\n",
    "\n",
    "        # Expanding path. Note how output of each layer is concatenated with the downsampling block\n",
    "        dec4 = self.upconv4(bottleneck)\n",
    "        dec4 = torch.cat((dec4, enc4), dim=1)\n",
    "        dec4 = self.decoder4(dec4)\n",
    "        dec3 = self.upconv3(dec4)\n",
    "        dec3 = torch.cat((dec3, enc3), dim=1)\n",
    "        dec3 = self.decoder3(dec3)\n",
    "        dec2 = self.upconv2(dec3)\n",
    "        dec2 = torch.cat((dec2, enc2), dim=1)\n",
    "        dec2 = self.decoder2(dec2)\n",
    "        dec1 = self.upconv1(dec2)\n",
    "        dec1 = torch.cat((dec1, enc1), dim=1)\n",
    "        dec1 = self.decoder1(dec1)\n",
    "        \n",
    "        out_conv = self.conv(dec1)\n",
    "        \n",
    "        return self.softmax(out_conv)\n",
    "\n",
    "    # This method executes the \"U-net block\"\n",
    "    def unet_block(self, in_channels, features, name):\n",
    "        return nn.Sequential(\n",
    "            OrderedDict(\n",
    "                [\n",
    "                    (\n",
    "                        name + \"conv1\",\n",
    "                        nn.Conv2d(\n",
    "                            in_channels=in_channels,\n",
    "                            out_channels=features,\n",
    "                            kernel_size=3,\n",
    "                            padding=1,\n",
    "                            bias=False,\n",
    "                        ),\n",
    "                    ),\n",
    "                    (name + \"norm1\", nn.BatchNorm2d(num_features=features)),\n",
    "                    (name + \"relu1\", nn.ReLU(inplace=True)),\n",
    "                    (\n",
    "                        name + \"conv2\",\n",
    "                        nn.Conv2d(\n",
    "                            in_channels=features,\n",
    "                            out_channels=features,\n",
    "                            kernel_size=3,\n",
    "                            padding=1,\n",
    "                            bias=False,\n",
    "                        ),\n",
    "                    ),\n",
    "                    (name + \"norm2\", nn.BatchNorm2d(num_features=features)),\n",
    "                    (name + \"relu2\", nn.ReLU(inplace=True)),\n",
    "                ]\n",
    "            )\n",
    "        )"
   ]
  },
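  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick sanity check (an addition to this notebook, not part of the original exercise):\n",
    "# the transposed convolutions on the expanding path, with kernel_size=2 and stride=2,\n",
    "# exactly double the spatial dimensions - mirroring how each MaxPool2d halves them\n",
    "# on the way down.\n",
    "\n",
    "demo_up = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)\n",
    "demo_in = torch.zeros(1, 16, 32, 32)\n",
    "demo_up(demo_in).shape  # expected: torch.Size([1, 8, 64, 64])"
   ]
  },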
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the training data\n",
    "\n",
    "Load training data and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_volume = nib.load(\"data/spleen1_img.nii.gz\").get_fdata()\n",
    "training_label = nib.load(\"data/spleen1_label.nii.gz\").get_fdata()"
   ]
  },
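  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick sanity check (an addition to this notebook): confirm that the image and\n",
    "# label volumes have the same shape before we start slicing them up\n",
    "\n",
    "training_volume.shape, training_label.shape"
   ]
  },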
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what we are about to segment and where's that spleen thing located:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(training_volume[:,:,5] + training_label[:,:,5]*500, cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We assume our label has one-hot encoding. Let's confirm how many distinct classes do we have in our label volume\n",
    "\n",
    "np.unique(training_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up the training device. Do we have a GPU? "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if (torch.cuda.is_available()):\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    \n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the U-Net class and get it ready for training. If you want to experiment with traing hyperparameters, you can re-execute this cell to reset the model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we will use one input channel (one image at a time) and two output channels (background and label)\n",
    "unet = UNet(1, 2) \n",
    "\n",
    "# Move all trainable parameters to the device\n",
    "unet.to(device)\n",
    "\n",
    "# We will use Cross Entropy loss function for this one - we are performing per-voxel classification task, \n",
    "# so it should do ok.\n",
    "# Later in the lesson we will discuss what are some of the other options for measuring medical image \n",
    "# segmentation performance.\n",
    "\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# You can play with learning rate later to see what yields best results\n",
    "optimizer = optim.Adam(unet.parameters(), lr=0.001)\n",
    "optimizer.zero_grad()"
   ]
  },
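  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch (an addition to this notebook) of the tensor shapes that\n",
    "# CrossEntropyLoss expects for segmentation: predictions of shape (batch, classes, w, h)\n",
    "# and integer class targets of shape (batch, w, h). This is why the training loop\n",
    "# below converts the label slice to a long tensor.\n",
    "\n",
    "demo_pred = torch.randn(1, 2, 4, 4)           # 2 classes over a tiny 4x4 \"image\"\n",
    "demo_target = torch.randint(0, 2, (1, 4, 4))  # integer class IDs (int64 by default)\n",
    "loss(demo_pred, demo_target)"
   ]
  },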
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# By the way, how many trainable parameters does our model have? If you will be playing \n",
    "# with 3D convolutions - compare the difference between 2D and 3D versions.\n",
    "\n",
    "sum(p.numel() for p in unet.parameters() if p.requires_grad)"
   ]
  },
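  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To give a feel for that comparison (an addition to this notebook): a single\n",
    "# 3x3 2D convolution versus its 3x3x3 3D counterpart, for the same channel counts\n",
    "\n",
    "conv2d = nn.Conv2d(32, 64, kernel_size=3)\n",
    "conv3d = nn.Conv3d(32, 64, kernel_size=3)\n",
    "sum(p.numel() for p in conv2d.parameters()), sum(p.numel() for p in conv3d.parameters())"
   ]
  },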
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From here on let's create the training loop\n",
    "\n",
    "## The training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# This is a basic training loop. Complete the code to run the model on first 15 slices \n",
    "# of the volume (that is where the spleen segmentation is - if you include more, you run the chances of background class\n",
    "# overwhelming your tiny network with cross-entropy loss that we are using)\n",
    "\n",
    "# Set up the model for training\n",
    "unet.train()\n",
    "\n",
    "for epoch in range(0,10):\n",
    "    for slice_ix in range(0,15):\n",
    "        # Let's extract the slice from the volume and convert it to tensor that the model will understand. \n",
    "        # Note that we normalize the volume to 0..1 range\n",
    "        slc = training_volume[:,:,slice_ix].astype(np.single)/np.max(training_volume[:,:,slice_ix])\n",
    "        \n",
    "        # Our model accepts a tensor of size (batch_size, channels, w, h). We have batch of 1 and one channel, \n",
    "        # So create the missing dimensions. Also move data to our device\n",
    "        slc_tensor = torch.from_numpy(slc).unsqueeze(0).unsqueeze(0).to(device)\n",
    "        \n",
    "        # TASK: Now extract the slice from label volume into tensor that the network will accept.\n",
    "        # Keep in mind, our cross entropy loss expects integers\n",
    "        \n",
    "        # <YOUR CODE HERE>\n",
    "        # ___SOLUTION\n",
    "    \n",
    "        lbl = training_label[:,:,slice_ix]\n",
    "        lbl_tensor = torch.from_numpy(lbl).unsqueeze(0).long().to(device)\n",
    "        # ___SOLUTION\n",
    "\n",
    "        # Zero-out gradients from the previous pass so that we can start computation from scratch for this backprop run\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # Do the forward pass\n",
    "        pred = unet(slc_tensor)\n",
    "        \n",
    "        # Here we compute our loss function and do the backpropagation pass\n",
    "        l = loss(pred, lbl_tensor)\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "    print(f\"Epoch: {epoch}, training loss: {l}\")\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here's a neat trick: let's visualize our last network prediction with default colormap in matplotlib:\n",
    "# (note the .cpu().detach() calls - we need to move our data to CPU before manipulating it and we need to \n",
    "# stop collecting the computation graph)\n",
    "\n",
    "plt.imshow(pred.cpu().detach()[0,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Isn't this a nice visualization of what draws our CNN's attention after these few runs? \n",
    "\n",
    "Try re-executing the training cell a few times to see how this attention improves as network learns more"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save results\n",
    "\n",
    "Lastly, let's run inference on all slices of our volume, turn them into the binary map and save as NIFTI! We will use it in the next exercise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's run inference on just one slice first\n",
    "\n",
    "# Switch model to the eval mode so that no gradient collection happens\n",
    "unet.eval()\n",
    "\n",
    "# TASK: pick a slice from the loaded training_volume Numpy array, convert it into PyTorch tensor,\n",
    "# and run an inference on it, convert result into 2D NumPy array and visualize it. \n",
    "# Don't forget to normalize your data before running inference! Also keep in mind\n",
    "# that our CNN return 2 channels - one for each class - target (spleen) and background \n",
    "\n",
    "# <YOUR CODE HERE>\n",
    "# ___SOLUTION\n",
    "\n",
    "def inference(img):\n",
    "    tsr_test = torch.from_numpy(img.astype(np.single)/np.max(img)).unsqueeze(0).unsqueeze(0)\n",
    "    pred = unet(tsr_test.to(device))\n",
    "    return np.squeeze(pred.cpu().detach())\n",
    "\n",
    "level = 11\n",
    "\n",
    "img_test = training_volume[:,:,level]\n",
    "pred = inference(img_test)\n",
    "\n",
    "plt.imshow(pred[1])\n",
    "# ___SOLUTION\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now let's convert this into binary mask using PyTorch's argmax function:\n",
    "\n",
    "mask = torch.argmax(pred, dim=0)\n",
    "plt.imshow(mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TASK: Now you have all you need to create a full NIFTI volume. Compute segmentation predictions for each slice of your volume\n",
    "# and turn them into NumPy array of the same shape as the original volume\n",
    "\n",
    "# <YOUR CODE HERE>\n",
    "# ___SOLUTION\n",
    "mask3d = np.zeros(training_volume.shape)\n",
    "\n",
    "for slc_ix in range(training_volume.shape[2]):\n",
    "    pred = inference(training_volume[:,:,slc_ix])\n",
    "    mask3d[:,:,slc_ix] = torch.argmax(pred, dim=0)\n",
    "    \n",
    "# ___SOLUTION"
   ]
  },
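  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick visual check (an addition to this notebook): overlay the predicted 3D mask\n",
    "# on the same slice we visualized at the very beginning, using the same intensity trick\n",
    "\n",
    "plt.imshow(training_volume[:,:,5] + mask3d[:,:,5]*500, cmap=\"gray\")"
   ]
  },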
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One thing to note here: \n",
    "\n",
    "Remember that [IPP/IOP](http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.6.2.html) and \"affine\" thing I mentioned during lesson on DICOM and NIFTI file formats? Let's check it, just in case if this volume is not aligned perfectly against the coordinate axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "org_volume = nib.load(\"data/spleen1_img.nii.gz\")\n",
    "org_volume.affine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And it's not!\n",
    "\n",
    "It is quite important to not forget your coordinate system transforms it when you are saving your mask if you want it to line up with your volume properly. Sometimes your affine is just an identity matrix, but sometimes (like on this volume) it's quite a bit more interesting. If you want to geek out on NIFTI affines, take a look at this: https://nipy.org/nibabel/coordinate_systems.html"
   ]
  },
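  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch (an addition to this notebook) of what the affine does: it maps\n",
    "# voxel indices (i, j, k) into the scanner's world coordinates, in millimeters.\n",
    "# nibabel ships a helper for exactly this:\n",
    "\n",
    "from nibabel.affines import apply_affine\n",
    "print(apply_affine(org_volume.affine, [0, 0, 0]))  # world position of the corner voxel\n",
    "print(apply_affine(org_volume.affine, [1, 0, 0]))  # one voxel over along the first axis"
   ]
  },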
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, save our NIFTI image, copying affine from the original image\n",
    "\n",
    "img_out = nib.Nifti1Image(mask3d, org_volume.affine)\n",
    "nib.save(img_out, \"data/out.nii.gz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concluding remarks\n",
    "Congrats, you have finished the exercise and trained your very own U-net for biomedical image segmentation!\n",
    "\n",
    "If you want to dive deeper, here are a few suggestions:\n",
    "\n",
    "* Try playing with the depth of UNet stacks and see how it affects your quality.\n",
    "* In the last exercise you have been writing your own convolutions. Try changing the model architecture to one that that works with 2.5D or 3D convolutions.\n",
    "* Later in this lesson we will talk about Dice and Jaccard similarity measures. Once you get familiar with that content, try changing the loss function and see how it affects the performance. Look up \"soft dice\" loss function"
   ]
  }
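  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last bullet above mentions the \"soft Dice\" loss. Here is a minimal sketch of\n",
    "# one common formulation (an addition to this notebook, not necessarily the exact\n",
    "# function discussed later in the lesson). Dice measures the overlap between the\n",
    "# predicted foreground probabilities and the ground truth mask; we minimize 1 - Dice.\n",
    "\n",
    "def soft_dice_loss(pred, target, epsilon=1e-6):\n",
    "    # pred: (batch, w, h) foreground probabilities; target: (batch, w, h) binary mask\n",
    "    intersection = (pred * target).sum()\n",
    "    return 1 - (2 * intersection + epsilon) / (pred.sum() + target.sum() + epsilon)\n",
    "\n",
    "# Hypothetical usage inside the training loop above (channel 1 is the spleen class):\n",
    "# l = soft_dice_loss(unet(slc_tensor)[:, 1], lbl_tensor.float())"
   ]
  }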
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}