--- a +++ b/libs/network/unet_df.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from libs.network.unet import conv_block, up_conv + +class SelFuseFeature(nn.Module): + def __init__(self, in_channels, shift_n=5, n_class=4, auxseg=False): + super(SelFuseFeature, self).__init__() + + self.shift_n = shift_n + self.n_class = n_class + self.auxseg = auxseg + self.fuse_conv = nn.Sequential(nn.Conv2d(in_channels*2, in_channels, kernel_size=1, padding=0), + nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + ) + if auxseg: + self.auxseg_conv = nn.Conv2d(in_channels, self.n_class, 1) + + + def forward(self, x, df): + N, _, H, W = df.shape + mag = torch.sqrt(torch.sum(df ** 2, dim=1)) + greater_mask = mag > 0.5 + greater_mask = torch.stack([greater_mask, greater_mask], dim=1) + df[~greater_mask] = 0 + + scale = 1. + + grid = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W)), dim=0) + grid = grid.expand(N, -1, -1, -1).to(x.device, dtype=torch.float).requires_grad_() + grid = grid + scale * df + + grid = grid.permute(0, 2, 3, 1).transpose(1, 2) + grid_ = grid + 0. + grid[...,0] = 2*grid_[..., 0] / (H-1) - 1 + grid[...,1] = 2*grid_[..., 1] / (W-1) - 1 + + # features = [] + select_x = x.clone() + for _ in range(self.shift_n): + select_x = F.grid_sample(select_x, grid, mode='bilinear', padding_mode='border') + # features.append(select_x) + # select_x = torch.mean(torch.stack(features, dim=0), dim=0) + # features.append(select_x.detach().cpu().numpy()) + # np.save("/root/chengfeng/Cardiac/source_code/logs/acdc_logs/logs_temp/feature.npy", np.array(features)) + if self.auxseg: + auxseg = self.auxseg_conv(x) + else: + auxseg = None + + select_x = self.fuse_conv(torch.cat([x, select_x], dim=1)) + return [select_x, auxseg] + +class U_NetDF(nn.Module): + def __init__(self,img_ch=1,num_class=4, selfeat=False, shift_n=5, auxseg=False): + super(U_NetDF,self).__init__() + + self.selfeat = selfeat + self.shift_n = shift_n + + self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) + + self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) + self.Conv2 = conv_block(ch_in=64,ch_out=128) + self.Conv3 = conv_block(ch_in=128,ch_out=256) + self.Conv4 = conv_block(ch_in=256,ch_out=512) + self.Conv5 = conv_block(ch_in=512,ch_out=1024) + + self.Up5 = up_conv(ch_in=1024,ch_out=512) + self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) + + self.Up4 = up_conv(ch_in=512,ch_out=256) + self.Up_conv4 = conv_block(ch_in=512, ch_out=256) + + self.Up3 = up_conv(ch_in=256,ch_out=128) + self.Up_conv3 = conv_block(ch_in=256, ch_out=128) + + self.Up2 = up_conv(ch_in=128,ch_out=64) + self.Up_conv2 = conv_block(ch_in=128, ch_out=64) + + # Direct Field + self.ConvDf_1x1 = nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0) + + if selfeat: + self.SelDF = SelFuseFeature(64, auxseg=auxseg, shift_n=shift_n) + + self.Conv_1x1 = nn.Conv2d(64,num_class,kernel_size=1,stride=1,padding=0) + + def forward(self, inputs): + x = inputs + + # encoding path + x1 = self.Conv1(x) + + x2 = self.Maxpool(x1) + x2 = self.Conv2(x2) + + x3 = self.Maxpool(x2) + x3 = self.Conv3(x3) + + x4 = self.Maxpool(x3) + x4 = self.Conv4(x4) + + x5 = self.Maxpool(x4) + x5 = self.Conv5(x5) + + # decoding + concat path + d5 = self.Up5(x5) + d5 = torch.cat((x4,d5),dim=1) + + d5 = self.Up_conv5(d5) + + d4 = self.Up4(d5) + d4 = torch.cat((x3,d4),dim=1) + d4 = self.Up_conv4(d4) + + d3 = self.Up3(d4) + d3 = torch.cat((x2,d3),dim=1) + d3 = self.Up_conv3(d3) + + # df = self.ConvDf_1x1(d3) + # # df = F.interpolate(inputs[1], size=d3.shape[-2:], mode='bilinear', align_corners=True) + # if self.selfeat: + # d3 = self.SelDF(d3, df) + + + d2 = self.Up2(d3) + d2 = torch.cat((x1,d2),dim=1) + d2 = self.Up_conv2(d2) + + # Direct Field + df = self.ConvDf_1x1(d2) + # df = None + if self.selfeat: + d2_auxseg = self.SelDF(d2, df) + d2, auxseg = d2_auxseg[:2] + else: + auxseg = None + + # df = F.interpolate(df, size=x.shape[-2:], mode='bilinear', align_corners=True) + d1 = self.Conv_1x1(d2) + + return [d1, df, auxseg] + + + +if __name__ == "__main__": + + a = torch.randn(1, 1, 224, 224) + + model = U_NetDF(selfeat=True) + + out = model(a) + print(out[0].shape, out[1].shape, out[2].shape) +