Diff of /networks/att_unet.py [000000] .. [903821]

Switch to side-by-side view

--- a
+++ b/networks/att_unet.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+import torch
+
+
+class conv_block(nn.Module):
+    """
+    Convolution Block
+    """
+    def __init__(self, in_ch, out_ch):
+        super(conv_block, self).__init__()
+
+        self.conv = nn.Sequential(
+            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm3d(out_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm3d(out_ch),
+            nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class up_conv(nn.Module):
+    """
+    Up Convolution Block
+    """
+
+    def __init__(self, in_ch, out_ch):
+        super(up_conv, self).__init__()
+        self.up = nn.Sequential(
+            nn.Upsample(scale_factor=2,mode='trilinear'),
+            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
+            nn.BatchNorm3d(out_ch),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.up(x)
+        return x
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, in_channels, skip_channels, mid_channels):
+        super(AttentionBlock, self).__init__()
+        self.W_skip = nn.Sequential(nn.Conv3d(skip_channels, mid_channels, kernel_size=1),
+                                    nn.BatchNorm3d(mid_channels))
+        self.W_x = nn.Sequential(nn.Conv3d(in_channels, mid_channels, kernel_size=1),
+                                 nn.BatchNorm3d(mid_channels))
+        self.psi = nn.Sequential(nn.Conv3d(mid_channels, 1, kernel_size=1),
+                                 nn.BatchNorm3d(1),
+                                 nn.Sigmoid())
+
+    def forward(self, x_skip, x):
+        x_skip = self.W_skip(x_skip)
+        x = self.W_x(x)
+        out = self.psi(nn.ReLU(inplace=True)(x_skip + x))
+        return out * x_skip
+
+
+class AttentionUp(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(AttentionUp, self).__init__()
+        self.attention = AttentionBlock(in_ch, out_ch, out_ch)
+        self.conv1 = conv_block(in_ch+out_ch, out_ch)
+
+    def forward(self, x, x_skip):
+        # note : x_skip is the skip connection and x is the input from the previous block
+        x = nn.functional.interpolate(x, x_skip.shape[2:], mode='trilinear', align_corners=False)
+        x_attention = self.attention(x_skip, x)
+        # stack their channels to feed to both convolution blocks
+        x = torch.cat((x, x_attention), dim=1)
+        x = self.conv1(x)
+        return x
+
+
+class AttentionUNet(nn.Module):
+    def __init__(self, in_ch=3, out_ch=1):
+        super(AttentionUNet, self).__init__()
+
+        n1 = 16
+        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
+
+        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
+        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
+        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
+        self.Maxpool4 = nn.MaxPool3d(kernel_size=(2,2,1), stride=(2,2,1))
+
+        self.Conv1 = conv_block(in_ch, filters[0])
+        self.Conv2 = conv_block(filters[0], filters[1])
+        self.Conv3 = conv_block(filters[1], filters[2])
+        self.Conv4 = conv_block(filters[2], filters[3])
+        self.Conv5 = conv_block(filters[3], filters[4])
+
+        self.Up_conv5 = AttentionUp(filters[4], filters[3])
+
+        self.Up_conv4 = AttentionUp(filters[3], filters[2])
+
+        self.Up_conv3 = AttentionUp(filters[2], filters[1])
+
+        self.Up_conv2 = AttentionUp(filters[1], filters[0])
+
+        self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        e1 = self.Conv1(x)
+
+        e2 = self.Maxpool1(e1)
+        e2 = self.Conv2(e2)
+
+        e3 = self.Maxpool2(e2)
+        e3 = self.Conv3(e3)
+
+        e4 = self.Maxpool3(e3)
+        e4 = self.Conv4(e4)
+
+        e5 = self.Maxpool4(e4)
+        e5 = self.Conv5(e5)
+
+        d4 = self.Up_conv5(e5,e4)
+
+        d3 = self.Up_conv4(d4,e3)
+
+        d2 = self.Up_conv3(d3,e2)
+
+        d1 = self.Up_conv2(d2,e1)
+
+        out = self.Conv(d1)
+
+        return out
+
+
+if __name__ == '__main__':
+    x = torch.randn(2,1,96,96,48)
+    model = AttentionUNet(in_ch=1,out_ch=2)
+    y = model(x)
+    print(y.shape)
\ No newline at end of file