--- a +++ b/Models_3D.py @@ -0,0 +1,433 @@ +""" +This code was write by Dr. Jun Zhang. If you use this code please follow the licence of Attribution-NonCommercial-ShareAlike 4.0 International. + +""" + +import torch +import torch.nn as nn + + + +def center_crop(layer, n_size): + cropidx = (layer.size(2) - n_size) // 2 + return layer[:, :, cropidx:(cropidx + n_size), cropidx:(cropidx + n_size),cropidx:(cropidx + n_size)] + +class ModelBreast(nn.Module): + def __init__(self, in_channel, n_classes): + self.in_channel = in_channel + self.n_classes = n_classes + self.start_channel = 32 + + super(ModelBreast, self).__init__() + self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False) + self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False) + self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False) + self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False) + self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False) + self.ec5 = self.encoder(self.start_channel*4, self.start_channel*4, bias=False) + self.ec6 = self.encoder(self.start_channel*4, self.start_channel*8, bias=False) + self.ec7 = self.encoder(self.start_channel*8, self.start_channel*4, bias=False) + + self.pool = nn.MaxPool3d(2) + + self.dc1 = self.encoder(self.start_channel*4+self.start_channel*4, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc3 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc4 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc5 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc6 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc7 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1,padding=0, bias=False) + + self.up1 = self.decoder(self.start_channel*4, self.start_channel*4) + self.up2 = self.decoder(self.start_channel*2, self.start_channel*2) + self.up3 = self.decoder(self.start_channel*2, self.start_channel*2) + + def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.ReLU()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.ReLU()) + return layer + + + def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, + output_padding=0, bias=True): + layer = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, output_padding=output_padding, bias=bias), + nn.ReLU()) + return layer + + + def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.Sigmoid()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.Sigmoid()) + return layer + + def forward(self, x): + e0 = self.eninput(x) + e0 = self.ec1(e0) + + e1 = self.pool(e0) + e1 = self.ec2(e1) + e1 = self.ec3(e1) + + + e2 = self.pool(e1) + e2 = self.ec4(e2) + e2 = self.ec5(e2) + + e3 = self.pool(e2) + e3 = self.ec6(e3) + e3 = self.ec7(e3) + + d0 = torch.cat((self.up1(e3), center_crop(e2,e3.size(2)*2)), 1) + + + d0 = self.dc1(d0) + d0 = self.dc2(d0) + + + d1 = torch.cat((self.up2(d0), center_crop(e1,d0.size(2)*2)), 1) + + d1 = self.dc3(d1) + d1 = self.dc4(d1) + + d2 = torch.cat((self.up3(d1), center_crop(e0,d1.size(2)*2)), 1) + + d2 = self.dc5(d2) + d2 = self.dc6(d2) + d2 = self.dc7(d2) + + + return d2 + + + +class ModelTumor(nn.Module): + def __init__(self, in_channel, n_classes): + self.in_channel = in_channel + self.n_classes = n_classes + self.start_channel = 32 + + super(ModelTumor, self).__init__() + self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False) + self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False) + self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False) + self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False) + self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False) + self.ec5 = self.encoder(self.start_channel*4, self.start_channel*2, bias=False) + + + self.pool = nn.MaxPool3d(2) + + + self.dc1 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc3 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc4 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc5 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1,padding=0, bias=False) + + self.up1 = self.decoder(self.start_channel*2, self.start_channel*2) + self.up2 = self.decoder(self.start_channel*2, self.start_channel*2) + + def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.ReLU()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.ReLU()) + return layer + + + def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, + output_padding=0, bias=True): + layer = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, output_padding=output_padding, bias=bias), + nn.ReLU()) + return layer + + + def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.Sigmoid()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.Sigmoid()) + return layer + + def forward(self, x): + e0 = self.eninput(x) + e0 = self.ec1(e0) + + e1 = self.pool(e0) + e1 = self.ec2(e1) + e1 = self.ec3(e1) + + + e2 = self.pool(e1) + e2 = self.ec4(e2) + e2 = self.ec5(e2) + + + d0 = torch.cat((self.up1(e2), center_crop(e1,e2.size(2)*2)), 1) + + + d0 = self.dc1(d0) + d0 = self.dc2(d0) + + + d1 = torch.cat((self.up2(d0), center_crop(e0,d0.size(2)*2)), 1) + + d1 = self.dc3(d1) + d1 = self.dc4(d1) + d1 = self.dc5(d1) + + return d1 + + + +# Note we trained the model with the same size (96*96*96) of input and output. +# We used zero padding to guarantee the same size of output after filtering + +class ModelTumor_train(nn.Module): + def __init__(self, in_channel, n_classes): + self.in_channel = in_channel + self.n_classes = n_classes + self.start_channel = 32 + + super(ModelTumor_train, self).__init__() + self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False) + self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False) + self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False) + self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False) + self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False) + self.ec5 = self.encoder(self.start_channel*4, self.start_channel*2, bias=False) + + + self.pool = nn.MaxPool3d(2) + + + self.dc1 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc3 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc4 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc5 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1,padding=0, bias=False) + + self.up1 = self.decoder(self.start_channel*2, self.start_channel*2) + self.up2 = self.decoder(self.start_channel*2, self.start_channel*2) + + def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.ReLU()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.ReLU()) + return layer + + + def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, + output_padding=0, bias=True): + layer = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, output_padding=output_padding, bias=bias), + nn.ReLU()) + return layer + + + def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.Sigmoid()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.Sigmoid()) + return layer + + def forward(self, x): + e0 = self.eninput(x) + e0 = self.ec1(e0) + + e1 = self.pool(e0) + e1 = self.ec2(e1) + e1 = self.ec3(e1) + + + e2 = self.pool(e1) + e2 = self.ec4(e2) + e2 = self.ec5(e2) + + + d0 = torch.cat((self.up1(e2), e1), 1) + + + d0 = self.dc1(d0) + d0 = self.dc2(d0) + + + d1 = torch.cat((self.up2(d0), e0), 1) + + d1 = self.dc3(d1) + d1 = self.dc4(d1) + d1 = self.dc5(d1) + + return d1 + + +class ModelBreast_train(nn.Module): + def __init__(self, in_channel, n_classes): + self.in_channel = in_channel + self.n_classes = n_classes + self.start_channel = 32 + + super(ModelBreast_train, self).__init__() + self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False, batchnorm=True) + self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False, batchnorm=True) + self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False, batchnorm=True) + self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False, batchnorm=True) + self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False, batchnorm=True) + self.ec5 = self.encoder(self.start_channel*4, self.start_channel*4, bias=False, batchnorm=True) + self.ec6 = self.encoder(self.start_channel*4, self.start_channel*8, bias=False, batchnorm=True) + self.ec7 = self.encoder(self.start_channel*8, self.start_channel*4, bias=False, batchnorm=True) + + self.pool = nn.MaxPool3d(2) + + self.dc1 = self.encoder(self.start_channel*4+self.start_channel*4, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc3 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False) + self.dc4 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc5 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc6 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False) + self.dc7 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1,padding=0, bias=False) + + self.up1 = self.decoder(self.start_channel*4, self.start_channel*4) + self.up2 = self.decoder(self.start_channel*2, self.start_channel*2) + self.up3 = self.decoder(self.start_channel*2, self.start_channel*2) + + def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.ReLU()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.ReLU()) + return layer + + + def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, + output_padding=0, bias=True): + layer = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, output_padding=output_padding, bias=bias), + nn.ReLU()) + return layer + + + def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, + bias=False, batchnorm=True): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.BatchNorm3d(out_channels), + nn.Sigmoid()) + else: + layer = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), + nn.Sigmoid()) + return layer + + def forward(self, x): + e0 = self.eninput(x) + e0 = self.ec1(e0) + + e1 = self.pool(e0) + e1 = self.ec2(e1) + e1 = self.ec3(e1) + + + e2 = self.pool(e1) + e2 = self.ec4(e2) + e2 = self.ec5(e2) + + e3 = self.pool(e2) + e3 = self.ec6(e3) + e3 = self.ec7(e3) + + + d0 = torch.cat((self.up1(e3), e2), 1) + + + d0 = self.dc1(d0) + d0 = self.dc2(d0) + + + d1 = torch.cat((self.up2(d0), e1), 1) + + d1 = self.dc3(d1) + d1 = self.dc4(d1) + + d2 = torch.cat((self.up3(d1), e0), 1) + + d2 = self.dc5(d2) + d2 = self.dc6(d2) + d2 = self.dc7(d2) + + return d2 + + +def Dice_loss(input, target): + smooth = 0.00000001 + + y_true_f = input.view(-1) + y_pred_f = target.view(-1) + intersection = torch.sum(torch.mul(y_true_f,y_pred_f)) + + return 1 - ((2. * intersection ) / + (torch.mul(y_true_f,y_true_f).sum() + torch.mul(y_pred_f,y_pred_f).sum() + smooth)) + +def DICESEN_loss(input, target): + smooth = 0.00000001 + y_true_f = input.view(-1) + y_pred_f = target.view(-1) + intersection = torch.sum(torch.mul(y_true_f,y_pred_f)) + dice= (2. * intersection ) / (torch.mul(y_true_f,y_true_f).sum() + torch.mul(y_pred_f,y_pred_f).sum() + smooth) + sen = (1. * intersection ) / (torch.mul(y_true_f,y_true_f).sum() + smooth) + return 2-dice-sen \ No newline at end of file