import torch
import torch.nn as nn
from src.models.unet3d.building_blocks import Encoder, Decoder, DoubleConv, ExtResNetBlock
from torchsummary import summary
def number_of_features_per_level(init_channel_number, num_levels):
return [init_channel_number * 2 ** k for k in range(num_levels)]
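# e.g. number_of_features_per_level(64, 4) -> [64, 128, 256, 512]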
class Abstract3DUNet(nn.Module):
"""
Base class for standard and residual UNet.
Args:
in_channels (int): number of input channels
out_channels (int): number of output segmentation masks;
            Note that out_channels might correspond either to different semantic classes
            or to different binary segmentation masks.
It's up to the user of the class to interpret the out_channels and
            use the proper loss criterion during training (e.g. CrossEntropyLoss for multi-class
            or BCEWithLogitsLoss for two-class segmentation, respectively)
        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
            of feature maps at level k is given by the geometric progression: f_maps * 2 ** k, k=0,1,...,num_levels-1
final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....)
layer_order (string): determines the order of layers
in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
See `SingleConv` for more info
num_groups (int): number of groups for the GroupNorm
num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
        pool_kernel_size (int or tuple): the size of the pooling window in the encoder
        conv_padding (int or tuple): zero-padding added to all three sides of the input of each convolution
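
        Example (illustrative only; argument values below are arbitrary):
            >>> model = Abstract3DUNet(in_channels=1, out_channels=2, final_sigmoid=True,
            ...                        basic_module=DoubleConv, f_maps=32, num_levels=4)
            >>> logits, scores = model(torch.rand(1, 1, 64, 64, 64))  # both of shape (1, 2, 64, 64, 64)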
"""
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4,
conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, **kwargs):
super(Abstract3DUNet, self).__init__()
if isinstance(f_maps, int):
f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
# create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
encoders = []
for i, out_feature_num in enumerate(f_maps):
if i == 0:
encoder = Encoder(in_channels, out_feature_num,
apply_pooling=False, # skip pooling in the first encoder
basic_module=basic_module,
conv_layer_order=layer_order,
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
padding=conv_padding)
else:
encoder = Encoder(f_maps[i - 1], out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
pool_kernel_size=pool_kernel_size,
padding=conv_padding)
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
# create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
decoders = []
reversed_f_maps = list(reversed(f_maps))
for i in range(len(reversed_f_maps) - 1):
            if basic_module == DoubleConv:
                # DoubleConv decoders concatenate the upsampled features with the skip connection,
                # so the number of input channels is the sum of the two feature maps
                in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
            else:
                # ExtResNetBlock decoders use summation joining, so the channel count is unchanged
                in_feature_num = reversed_f_maps[i]
out_feature_num = reversed_f_maps[i + 1]
decoder = Decoder(in_feature_num, out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
padding=conv_padding)
decoders.append(decoder)
self.decoders = nn.ModuleList(decoders)
        # in the last layer a 1×1×1 convolution reduces the number of output
        # channels to the number of labels
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
if final_sigmoid:
self.final_activation = nn.Sigmoid()
else:
self.final_activation = nn.Softmax(dim=1)
def forward(self, x):
# encoder part
encoders_features = []
for encoder in self.encoders:
x = encoder(x)
            # prepend the output so the list ends up in reverse order, aligned with the decoder path
            encoders_features.insert(0, x)
        # remove the last (deepest) encoder's output from the list; since outputs were prepended,
        # it is the first element and is passed to the decoder directly as `x`
        encoders_features = encoders_features[1:]
# decoder part
for decoder, encoder_features in zip(self.decoders, encoders_features):
# pass the output from the corresponding encoder and the output
# of the previous decoder
x = decoder(encoder_features, x)
x = self.final_conv(x)
scores = self.final_activation(x)
return x, scores
class UNet3D(Abstract3DUNet):
"""
3DUnet model from
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
<https://arxiv.org/pdf/1606.06650.pdf>`.
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, conv_padding=1, **kwargs):
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels,
conv_padding=conv_padding, **kwargs)
def test(self):
classes = 4
in_channels = 4
input_tensor = torch.rand(1, in_channels, 32, 32, 32)
ideal_out = torch.rand(1, classes, 32, 32, 32)
out_pred, out_scores = self.forward(input_tensor)
assert ideal_out.shape == out_pred.shape
print("UNet3D test is complete")
class ResidualUNet3D(Abstract3DUNet):
"""
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
Uses ExtResNetBlock as a basic building block, summation joining instead
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=5, conv_padding=1, **kwargs):
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels, conv_padding=conv_padding, **kwargs)
def test(self):
classes = 4
in_channels = 4
input_tensor = torch.rand(1, in_channels, 32, 32, 32)
ideal_out = torch.rand(1, classes, 32, 32, 32)
out_pred, _ = self.forward(input_tensor)
assert ideal_out.shape == out_pred.shape
print("ResidualUNet3D test is complete")
if __name__ == "__main__":
    net = ResidualUNet3D(in_channels=4, out_channels=4, f_maps=16)
    # count trainable parameters
    params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(params)
net.test()
unet = UNet3D(in_channels=4, out_channels=4, f_maps=16, final_sigmoid=True, layer_order='crg',
num_groups=8, num_levels=4, conv_padding=1)
    params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
print(params)
unet.test()
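    # Minimal usage sketch (not part of the original tests; synthetic data only): the first
    # output of `forward` holds the raw logits of the final 1x1x1 convolution and the second
    # holds the activated scores, so the logits can be fed straight to nn.CrossEntropyLoss.
    volume = torch.rand(1, 4, 32, 32, 32)
    target = torch.randint(0, 4, (1, 32, 32, 32))
    logits, scores = unet(volume)
    loss = nn.CrossEntropyLoss()(logits, target)
    print(f"cross-entropy loss on random data: {loss.item():.4f}")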