# -*- coding: utf-8 -*-
"""01_model.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1OWXPL8K-jKC4KgGmYXXkeBWOL2V1biuf
# Setup
"""
import torch
from fastai.callbacks import *
from fastai.vision import *
# Spatial size of one input volume. These match the last three dimensions
# quoted in the layer shape comments below (160 x 192 x 128); their product
# is used by KLDivergence as its normalisation constant.
H, W, D = 160, 192, 128
def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):
    """Build a GroupNorm -> ReLU -> Conv3d stack (in that order).

    Args:
        c_in: number of input channels (normalised by the GroupNorm).
        c_out: number of output channels of the convolution.
        ks: convolution kernel size.
        num_groups: number of GroupNorm groups. When omitted (or falsy) it
            defaults to ``c_in // 2`` for an even ``c_in`` and to a single
            group for an odd ``c_in``.
        **conv_kwargs: forwarded unchanged to ``nn.Conv3d``
            (e.g. ``stride``, ``padding``).

    Returns:
        An ``nn.Sequential`` holding the three layers.
    """
    if not num_groups:
        # GroupNorm needs an integer group count that divides c_in. An odd
        # c_in cannot be split into pairs, so fall back to one group instead
        # of passing None, which GroupNorm rejects.
        num_groups = c_in // 2 if c_in % 2 == 0 else 1
    return nn.Sequential(nn.GroupNorm(num_groups, c_in),
                         nn.ReLU(),
                         nn.Conv3d(c_in, c_out, ks, **conv_kwargs))
def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):
    """Build a ResNet-like block: two 3x3x3 `conv_block`s followed by a merge.

    Args:
        nf: number of channels entering and leaving the block.
        num_groups: GroupNorm group count passed to both `conv_block`s.
        bottle_neck: when True the inner convolution works on ``nf // 2``
            channels; otherwise it keeps ``nf`` channels.
        **conv_kwargs: extra keyword arguments forwarded to both
            `conv_block`s. ``stride`` and ``padding`` are already fixed here,
            so passing them again raises a duplicate-keyword error.

    Returns:
        A fastai ``SequentialEx`` ending in ``MergeLayer()``, which merges the
        block input back into the output of the two convolutions.
    """
    # Integer division: a float channel count is not accepted by Conv3d.
    nf_inner = nf // 2 if bottle_neck else nf
    return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),
                        conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),
                        MergeLayer())
def upsize(c_in, c_out, ks=1, scale=2):
    """Change the channel count with a Conv3d, then enlarge the volume.

    The convolution maps ``c_in`` channels to ``c_out`` with kernel size ``ks``
    (1x1x1 by default); the result is upsampled by ``scale`` along every
    spatial axis using trilinear interpolation with ``align_corners=True``.
    """
    channel_mixer = nn.Conv3d(c_in, c_out, ks)
    enlarger = nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=True)
    return nn.Sequential(channel_mixer, enlarger)
def hook_debug(module, input, output):
    """
    Debugging hook: report which module fired and the size of its output.
    ---------------------------------------------------------------------
    The output is returned unchanged.

    Example:
    Hooks(ms, hook_debug, is_forward=True, detach=False)
    """
    layer_name = module.__class__.__name__
    print(f'Hooking {layer_name}')
    out_size = output.data.size()
    print('output size:', out_size)
    return output
class Encoder(nn.Module):
    """Encoder part.

    A stem (``conv1`` + ``res_block1``) followed by eight stages, each made of
    ``conv_block<i>`` and ``res_block<i+1>`` (i = 1..8). Stages 1, 3 and 5 use
    a stride-2 convolution, so a (1, 4, 160, 192, 128) input is reduced to
    (1, 256, 20, 24, 16).

    Sub-module names (``conv1``, ``res_block1`` .. ``res_block9``,
    ``conv_block1`` .. ``conv_block8``) are referenced from outside the class
    when hooks are attached, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Stem: 4 input channels -> 32 features, spatial size unchanged.
        self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1)
        self.res_block1 = reslike_block(32, num_groups=8)
        # (c_in, c_out, stride) of conv_block1 .. conv_block8. Every stage is
        # followed by a residual block working on c_out channels.
        stage_specs = ((32, 64, 2), (64, 64, 1),
                       (64, 128, 2), (128, 128, 1),
                       (128, 256, 2), (256, 256, 1),
                       (256, 256, 1), (256, 256, 1))
        for idx, (c_in, c_out, stride) in enumerate(stage_specs, start=1):
            setattr(self, f'conv_block{idx}',
                    conv_block(c_in, c_out, 3, num_groups=8, stride=stride, padding=1))
            setattr(self, f'res_block{idx + 1}', reslike_block(c_out, num_groups=8))

    def forward(self, x):
        # Shapes below assume a (1, 4, 160, 192, 128) input.
        x = self.res_block1(self.conv1(x))          # (1, 32, 160, 192, 128)
        for idx in range(1, 9):
            # Stages 1-2 give (1, 64, 80, 96, 64), stages 3-4 give
            # (1, 128, 40, 48, 32), stages 5-8 give (1, 256, 20, 24, 16).
            x = getattr(self, f'conv_block{idx}')(x)
            x = getattr(self, f'res_block{idx + 1}')(x)
        return x
class Decoder(nn.Module):
    """Decoder part.

    Three levels of ``upsize<i>`` -> add skip connection -> ``reslike<i>``,
    then a 1x1x1 convolution to 3 channels and a sigmoid.

    The skip connections are not passed in: ``forward`` reads them from the
    module-level ``hooks`` object (``hooks.stored[2]``, ``[1]``, ``[0]`` for
    levels 1, 2, 3), so ``hooks`` must exist and must have been filled by an
    encoder pass before this module is called.
    """

    def __init__(self):
        super().__init__()
        # (c_in, c_out) of upsize1 .. upsize3; each is followed by a residual
        # block on c_out channels.
        for level, (c_in, c_out) in enumerate(((256, 128), (128, 64), (64, 32)), start=1):
            setattr(self, f'upsize{level}', upsize(c_in, c_out))
            setattr(self, f'reslike{level}', reslike_block(c_out, num_groups=8))
        self.conv1 = nn.Conv3d(32, 3, 1)
        self.sigmoid1 = torch.nn.Sigmoid()

    def forward(self, x):
        # For a (1, 256, 20, 24, 16) input the three levels produce
        # (1, 128, 40, 48, 32), (1, 64, 80, 96, 64) and (1, 32, 160, 192, 128).
        for level in (1, 2, 3):
            x = getattr(self, f'upsize{level}')(x)
            # Deepest stored activation first: indices 2, 1, 0.
            x = x + hooks.stored[3 - level]
            x = getattr(self, f'reslike{level}')(x)
        x = self.conv1(x)                            # (1, 3, 160, 192, 128)
        return self.sigmoid1(x)
class VAEEncoder(nn.Module):
    """Variational auto-encoder encoder part.

    Reduces the encoder features to a 256-vector per sample, predicts the
    parameters of a diagonal Gaussian and returns one sample drawn from it
    with the reparameterisation trick (``mean + std * noise``).

    Args:
        latent_dim: size of the latent vector.

    Notes:
        * The output of the ``z_log_var`` layer is exponentiated and used
          directly as the scale of the noise, i.e. it acts as log(sigma).
        * ``linear1`` expects 60 values per row, which fixes the input to
          (B, 256, 20, 24, 16): the stride-2 conv gives (B, 16, 10, 12, 8),
          and 16*10*12*8 = 256*60.
    """

    def __init__(self, latent_dim:int=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)
        self.linear1 = nn.Linear(60, 1)
        # Assumed latent variable's probability density function parameters
        self.z_mean = nn.Linear(256, latent_dim)
        self.z_log_var = nn.Linear(256, latent_dim)
        # No longer used in forward(): the noise must be re-sampled on every
        # call, not stored and trained. The attribute is kept (frozen) only so
        # that state dicts saved with the previous layout still load.
        self.epsilon = nn.Parameter(torch.randn(1, latent_dim), requires_grad=False)

    def forward(self, x):
        x = self.conv_block(x)              # (B, 16, 10, 12, 8)
        batch = x.size(0)
        x = x.view(batch, 256, -1)          # (B, 256, 60)
        x = self.linear1(x)                 # (B, 256, 1)
        x = x.view(batch, 256)              # (B, 256)
        z_mean = self.z_mean(x)             # (B, latent_dim)
        z_std = self.z_log_var(x).exp()     # (B, latent_dim)
        # Fresh standard-normal noise per call, on the same device/dtype.
        noise = torch.randn_like(z_mean)
        return z_mean + z_std * noise       # (B, latent_dim)
class VAEDecoder(nn.Module):
    """Variational auto-encoder decoder part.

    Expands a latent vector back to a 4-channel volume: a linear layer and
    ReLU produce 256*60 values that are reshaped to (B, 16, 10, 12, 8), then
    four `upsize` steps (the last three each followed by a residual block)
    bring it to (B, 32, 160, 192, 128) and a 1x1x1 convolution maps it to
    4 channels.

    Args:
        latent_dim: size of the incoming latent vector (default 128).
    """

    def __init__(self, latent_dim:int=128):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, 256*60)
        self.relu1 = nn.ReLU()
        self.upsize1 = upsize(16, 256)
        self.upsize2 = upsize(256, 128)
        self.reslike1 = reslike_block(128, num_groups=8)
        self.upsize3 = upsize(128, 64)
        self.reslike2 = reslike_block(64, num_groups=8)
        self.upsize4 = upsize(64, 32)
        self.reslike3 = reslike_block(32, num_groups=8)
        self.conv1 = nn.Conv3d(32, 4, 1)

    def forward(self, x):
        x = self.relu1(self.linear1(x))     # (B, 256*60)
        # 256*60 == 16*10*12*8; -1 lets the batch size be inferred instead of
        # being pinned to 1.
        x = x.view(-1, 16, 10, 12, 8)       # (B, 16, 10, 12, 8)
        x = self.upsize1(x)                 # (B, 256, 20, 24, 16)
        x = self.upsize2(x)                 # (B, 128, 40, 48, 32)
        x = self.reslike1(x)                # (B, 128, 40, 48, 32)
        x = self.upsize3(x)                 # (B, 64, 80, 96, 64)
        x = self.reslike2(x)                # (B, 64, 80, 96, 64)
        x = self.upsize4(x)                 # (B, 32, 160, 192, 128)
        x = self.reslike3(x)                # (B, 32, 160, 192, 128)
        return self.conv1(x)                # (B, 4, 160, 192, 128)
class AutoUNet(nn.Module):
    """3D U-Net using autoencoder regularization.

    One shared `Encoder` feeds two branches: the `Decoder`, which yields the
    3-channel output, and the `VAEEncoder` -> `VAEDecoder` pair, which yields
    a 4-channel output. ``forward`` returns both as a tuple.

    The `Decoder` pulls its skip connections from the module-level ``hooks``
    object, so hooks must be registered on this model's encoder before the
    model is called.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.vencoder = VAEEncoder(latent_dim=128)
        self.vdecoder = VAEDecoder()

    def forward(self, input):
        features = self.encoder(input)
        # Decoder branch; (1, 3, 160, 192, 128) for a (1, 4, 160, 192, 128) input.
        decoder_out = self.decoder(features)
        # VAE branch; (1, 4, 160, 192, 128) for the same input.
        latent = self.vencoder(features)
        vae_out = self.vdecoder(latent)
        return decoder_out, vae_out
class SoftDiceLoss(nn.Module):
    """Soft dice loss based on the overlap between prediction and ground truth.

    Computes ``1 - dice / c`` where
    ``dice = 2 * sum(x * y) / (sum(x**2) + sum(y**2) + epsilon)``
    and the sums run over every element of the tensors.

    Args:
        epsilon: small constant added to the denominator to avoid division
            by zero.
        c: divisor applied to the dice score (default 3).

    NOTE(review): the dice score is computed over all elements at once and
    only then divided by ``c``, so the loss cannot drop below ``1 - 1/c``.
    If a per-channel dice averaged over ``c`` channels was intended, the sums
    need to be taken per channel - confirm with the training code.
    """

    def __init__(self, epsilon=1e-6, c=3):
        super().__init__()
        self.epsilon = epsilon
        # Honour the constructor argument; it used to be overwritten with 3.
        self.c = c

    def forward(self, x:Tensor, y:Tensor):
        intersection = 2 * (x * y).sum()
        union = (x**2).sum() + (y**2).sum()
        return 1 - (intersection / (union + self.epsilon)) / self.c
class KLDivergence(nn.Module):
    """KL divergence between the estimated normal distribution and a standard normal prior.

    Evaluates ``(1/N) * sum(mu**2 + sigma**2 - log(sigma**2) - 1)`` with
    ``sigma = exp(z_log_var)``.

    Despite its name, ``z_log_var`` is treated as log(sigma): `VAEEncoder` in
    this file scales its noise by ``exp(z_log_var)``, and this class squares
    that same quantity to obtain sigma**2. Consequently
    ``log(sigma**2) == 2 * z_log_var``.
    """
    # Normalisation constant: number of voxels in one input volume.
    N = H * W * D #hyperparameter check

    def __init__(self):
        super().__init__()

    def forward(self, z_mean:Tensor, z_log_var:Tensor):
        z_var = z_log_var.exp()  # sigma
        # log(sigma**2) is 2*log(sigma). Squaring the log instead lets the
        # expression go negative, which a KL divergence never does.
        return (1/self.N) * ( (z_mean**2 + z_var**2 - 2*z_log_var - 1).sum() )
class L2Loss(nn.Module):
    """Sum of squared element-wise differences between prediction and ground truth.

    This is the squared L2 norm of ``x - y`` (no square root, no averaging).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x:Tensor, y:Tensor):
        residual = x - y
        return residual.pow(2).sum()
# Build the model and capture the activations the rest of the file relies on.
autounet = AutoUNet()

# Order matters: Decoder.forward reads hooks.stored[0], [1] and [2] as the
# outputs of res_block1, res_block3 and res_block5. Entries 3 and 4 hold the
# outputs of the VAE encoder's z_mean and z_log_var layers.
ms = [getattr(autounet.encoder, block_name)
      for block_name in ('res_block1', 'res_block3', 'res_block5')]
ms.extend([autounet.vencoder.z_mean, autounet.vencoder.z_log_var])

# detach=False keeps the stored activations attached to the autograd graph.
hooks = hook_outputs(ms, detach=False, grad=False)

lr = 1e-4
optimizer = optim.Adam(autounet.parameters(), lr=lr)