b/model/Att_Unet.py

import torch
import torch.nn as nn


class Attention_block(nn.Module):
    """Additive attention gate: re-weights the skip connection x using the gating signal g."""

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # 1x1 convolution projecting the gating signal (decoder feature) to F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # 1x1 convolution projecting the skip connection (encoder feature) to F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # 1x1 convolution collapsing to a single-channel attention map in [0, 1]
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)   # additive attention
        psi = self.psi(psi)        # per-pixel attention coefficients
        out = x * psi              # gate the skip connection

        return out
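

# A minimal shape sketch for the attention gate (illustrative only; the tensors below are
# hypothetical and not part of the original code). It assumes g and x already share the same
# spatial resolution, as they do in Att_Unet.forward after each transposed convolution:
#
#   gate = Attention_block(F_g=256, F_l=256, F_int=128)
#   g = torch.randn(1, 256, 32, 32)   # gating signal from the decoder
#   x = torch.randn(1, 256, 32, 32)   # skip connection from the encoder
#   out = gate(g, x)                  # -> torch.Size([1, 256, 32, 32]), same shape as x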


class Att_Unet(nn.Module):
    def __init__(self, output_ch=1):
        super(Att_Unet, self).__init__()

        # Pretrained U-Net (init_features=32) from torch.hub; its encoder/decoder blocks are
        # reused below, with an attention gate inserted on every skip connection.
        # Note: the final 1x1 conv is taken from the pretrained model, so the network always
        # outputs a single channel regardless of output_ch.
        self.base_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                                         in_channels=3, out_channels=1, init_features=32,
                                         pretrained=True, verbose=False)
        self.base_layers = list(self.base_model.children())

        # Encoder: encoder blocks and max-pooling layers of the pretrained U-Net
        self.layer0 = nn.Sequential(*self.base_layers[0])      # enc1 -> 32 channels
        self.layer1 = nn.Sequential(*self.base_layers[1:3])    # pool1 + enc2 -> 64 channels
        self.layer2 = nn.Sequential(*self.base_layers[3:5])    # pool2 + enc3 -> 128 channels
        self.layer3 = nn.Sequential(*self.base_layers[5:7])    # pool3 + enc4 -> 256 channels

        self.layer4 = nn.Sequential(*self.base_layers[7:9])    # pool4 + bottleneck -> 512 channels

        # Decoder: transposed convolutions and decoder blocks of the pretrained U-Net,
        # each preceded by an attention gate on the corresponding skip connection
        self.Up5 = self.base_layers[9]                          # upconv: 512 -> 256
        self.Att5 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv5 = nn.Sequential(*self.base_layers[10])

        self.Up4 = self.base_layers[11]                         # upconv: 256 -> 128
        self.Att4 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv4 = nn.Sequential(*self.base_layers[12])

        self.Up3 = self.base_layers[13]                         # upconv: 128 -> 64
        self.Att3 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv3 = nn.Sequential(*self.base_layers[14])

        self.Up2 = self.base_layers[15]                         # upconv: 64 -> 32
        self.Att2 = Attention_block(F_g=32, F_l=32, F_int=16)
        self.Up_conv2 = nn.Sequential(*self.base_layers[16])

        # Final 1x1 convolution: 32 -> 1 output channel
        self.Conv_1x1 = self.base_layers[17]

    def forward(self, x):
        # encoding path
        x1 = self.layer0(x)    # 32 channels
        x2 = self.layer1(x1)   # 64 channels
        x3 = self.layer2(x2)   # 128 channels
        x4 = self.layer3(x3)   # 256 channels

        x5 = self.layer4(x4)   # 512 channels (bottleneck)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)   # raw logits; no sigmoid is applied here

        return d1
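

# Minimal smoke test (a sketch, not part of the original training pipeline). It assumes
# internet access for the torch.hub download and an input whose spatial size is divisible
# by 16 (four pooling stages); the 256x256 size below is an arbitrary example.
if __name__ == "__main__":
    model = Att_Unet()
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)   # batch of one 3-channel image
        out = model(dummy)
    print(out.shape)   # expected: torch.Size([1, 1, 256, 256])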