import torch
from torch import nn
class ConvBlock(nn.Module):
def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
super(ConvBlock, self).__init__()
ops = []
for i in range(n_stages):
if i==0:
input_channel = n_filters_in
else:
input_channel = n_filters_out
ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class ResidualConvBlock(nn.Module):
def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
super(ResidualConvBlock, self).__init__()
ops = []
for i in range(n_stages):
if i == 0:
input_channel = n_filters_in
else:
input_channel = n_filters_out
ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
if i != n_stages-1:
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = (self.conv(x) + x)
x = self.relu(x)
return x
class DownsamplingConvBlock(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(DownsamplingConvBlock, self).__init__()
ops = []
if normalization != 'none':
ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
else:
assert False
else:
ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class UpsamplingDeconvBlock(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(UpsamplingDeconvBlock, self).__init__()
ops = []
if normalization != 'none':
ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
else:
assert False
else:
ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class Upsampling(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(Upsampling, self).__init__()
ops = []
ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False))
ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class VNet(nn.Module):
def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
super(VNet, self).__init__()
self.has_dropout = has_dropout
self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, stride=(2,2,1), normalization=normalization)
self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, stride=(2,2,1), normalization=normalization)
self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
self.dropout = nn.Dropout3d(p=0.5, inplace=False)
# self.__init_weight()
def encoder(self, input):
x1 = self.block_one(input)
x1_dw = self.block_one_dw(x1)
x2 = self.block_two(x1_dw)
x2_dw = self.block_two_dw(x2)
x3 = self.block_three(x2_dw)
x3_dw = self.block_three_dw(x3)
x4 = self.block_four(x3_dw)
x4_dw = self.block_four_dw(x4)
x5 = self.block_five(x4_dw)
# x5 = F.dropout3d(x5, p=0.5, training=True)
if self.has_dropout:
x5 = self.dropout(x5)
res = [x1, x2, x3, x4, x5]
# print(x5.shape)
return res
def decoder(self, features):
x1 = features[0]
x2 = features[1]
x3 = features[2]
x4 = features[3]
x5 = features[4]
x5_up = self.block_five_up(x5)
# print(x5_up.shape)
x5_up = x5_up + x4
x6 = self.block_six(x5_up)
x6_up = self.block_six_up(x6)
x6_up = x6_up + x3
x7 = self.block_seven(x6_up)
x7_up = self.block_seven_up(x7)
x7_up = x7_up + x2
x8 = self.block_eight(x7_up)
x8_up = self.block_eight_up(x8)
x8_up = x8_up + x1
x9 = self.block_nine(x8_up)
# x9 = F.dropout3d(x9, p=0.5, training=True)
if self.has_dropout:
x9 = self.dropout(x9)
out = self.out_conv(x9)
return out
def forward(self, input, turnoff_drop=False):
if turnoff_drop:
has_dropout = self.has_dropout
self.has_dropout = False
features = self.encoder(input)
out = self.decoder(features)
if turnoff_drop:
self.has_dropout = has_dropout
return out
if __name__ == '__main__':
x = torch.randn(2,1,112,112,80)
model = VNet(n_channels=1,n_classes=2, normalization='batchnorm', has_dropout=True)
y = model(x)
print(y.shape)