Diff of /networks/u_net.py [000000] .. [903821]

Switch to unified view

a b/networks/u_net.py
1
import torch.nn as nn
2
import torch.nn.functional as F
3
import torch.utils.data
4
import torch
5
6
7
class conv_block(nn.Module):
8
    """
9
    Convolution Block
10
    """
11
    def __init__(self, in_ch, out_ch):
12
        super(conv_block, self).__init__()
13
14
        self.conv = nn.Sequential(
15
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
16
            nn.BatchNorm3d(out_ch),
17
            nn.ReLU(inplace=True),
18
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
19
            nn.BatchNorm3d(out_ch),
20
            nn.ReLU(inplace=True))
21
22
    def forward(self, x):
23
        x = self.conv(x)
24
        return x
25
26
27
class up_conv(nn.Module):
28
    """
29
    Up Convolution Block
30
    """
31
32
    def __init__(self, in_ch, out_ch):
33
        super(up_conv, self).__init__()
34
        self.up = nn.Sequential(
35
            nn.Upsample(scale_factor=2,mode='trilinear'),
36
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
37
            nn.BatchNorm3d(out_ch),
38
            nn.ReLU(inplace=True)
39
        )
40
41
    def forward(self, x):
42
        x = self.up(x)
43
        return x
44
45
46
class TrilinearUp(nn.Module):
47
    def __init__(self, in_ch, out_ch):
48
        super(TrilinearUp, self).__init__()
49
        self.conv1 = nn.Sequential(
50
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
51
            nn.BatchNorm3d(out_ch),
52
            nn.ReLU(inplace=True)
53
        )
54
        self.conv2 = conv_block(in_ch, out_ch)
55
56
    def forward(self, x, x_skip):
57
        # note : x_skip is the skip connection and x is the input from the previous block
58
        x = nn.functional.interpolate(x, x_skip.shape[2:], mode='trilinear', align_corners=False)
59
        x = self.conv1(x)
60
        # stack their channels to feed to both convolution blocks
61
        x = torch.cat((x, x_skip), dim=1)
62
        x = self.conv2(x)
63
        return x
64
65
66
class U_Net(nn.Module):
67
    """
68
    UNet - Basic Implementation
69
    Paper : https://arxiv.org/abs/1505.04597
70
    """
71
72
    def __init__(self, in_ch=3, out_ch=1):
73
        super(U_Net, self).__init__()
74
75
        n1 = 16
76
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
77
78
        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
79
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
80
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
81
        self.Maxpool4 = nn.MaxPool3d(kernel_size=(2,2,1), stride=(2,2,1))
82
83
        self.Conv1 = conv_block(in_ch, filters[0])
84
        self.Conv2 = conv_block(filters[0], filters[1])
85
        self.Conv3 = conv_block(filters[1], filters[2])
86
        self.Conv4 = conv_block(filters[2], filters[3])
87
        self.Conv5 = conv_block(filters[3], filters[4])
88
89
        self.Up_conv5 = TrilinearUp(filters[4], filters[3])
90
91
        self.Up_conv4 = TrilinearUp(filters[3], filters[2])
92
93
        self.Up_conv3 = TrilinearUp(filters[2], filters[1])
94
95
        self.Up_conv2 = TrilinearUp(filters[1], filters[0])
96
97
        self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
98
99
    def forward(self, x):
100
        e1 = self.Conv1(x)
101
102
        e2 = self.Maxpool1(e1)
103
        e2 = self.Conv2(e2)
104
105
        e3 = self.Maxpool2(e2)
106
        e3 = self.Conv3(e3)
107
108
        e4 = self.Maxpool3(e3)
109
        e4 = self.Conv4(e4)
110
111
        e5 = self.Maxpool4(e4)
112
        e5 = self.Conv5(e5)
113
114
        d4 = self.Up_conv5(e5,e4)
115
116
        d3 = self.Up_conv4(d4,e3)
117
118
        d2 = self.Up_conv3(d3,e2)
119
120
        d1 = self.Up_conv2(d2,e1)
121
122
        out = self.Conv(d1)
123
124
        return out