Diff of /ecgtoBR/BRnet.py [000000] .. [c0487b]

Switch to unified view

a b/ecgtoBR/BRnet.py
1
import torch.nn as nn
2
import torch
3
import torch.nn.functional as F
4
from torch.autograd import Variable
5
6
import numpy as np
7
8
class IncResBlock(nn.Module): ### Inception Resblocks
9
    
10
    def __init__(self, inplanes, planes, convstr=1, convsize = 15, convpadding = 7):
11
        
12
        super(IncResBlock, self).__init__()
13
        
14
        self.Inputconv1x1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride = 1, bias=False)
15
        
16
        self.conv1_1 = nn.Sequential(
17
            nn.Conv1d(in_channels = inplanes,out_channels = planes//4,kernel_size = convsize,stride = convstr,padding = convpadding),
18
            nn.BatchNorm1d(planes//4))
19
        
20
        self.conv1_2 = nn.Sequential(
21
            nn.Conv1d(inplanes, planes//4, kernel_size=1, stride = convstr, padding=0, bias=False),
22
            nn.BatchNorm1d(planes//4),
23
            nn.LeakyReLU(0.2,),
24
            nn.Conv1d(in_channels = planes//4,out_channels = planes//4,kernel_size = convsize+2,stride = convstr,padding = convpadding+1),
25
            nn.BatchNorm1d(planes//4))
26
        
27
        self.conv1_3 = nn.Sequential(
28
            nn.Conv1d(inplanes, planes//4, kernel_size=1, stride = convstr, padding=0, bias=False),
29
            nn.BatchNorm1d(planes//4),
30
            nn.LeakyReLU(0.2,),
31
            nn.Conv1d(in_channels = planes//4,out_channels = planes//4,kernel_size = convsize+4,stride = convstr,padding = convpadding+2),
32
            nn.BatchNorm1d(planes//4))
33
        
34
        self.conv1_4 = nn.Sequential(
35
            nn.Conv1d(inplanes, planes//4, kernel_size=1, stride = convstr, padding=0, bias=False),
36
            nn.BatchNorm1d(planes//4),
37
            nn.LeakyReLU(0.2,),
38
            nn.Conv1d(in_channels = planes//4,out_channels = planes//4,kernel_size = convsize+6,stride = convstr,padding = convpadding+3),
39
            nn.BatchNorm1d(planes//4))
40
        
41
        self.relu = nn.ReLU()
42
    
43
    
44
    def forward(self, x):
45
        
46
        residual = self.Inputconv1x1(x)
47
48
        c1 = self.conv1_1(x)
49
        c2 = self.conv1_2(x)
50
        c3 = self.conv1_3(x)
51
        c4 = self.conv1_4(x)
52
53
        out = torch.cat([c1,c2,c3,c4],1)
54
        out += residual
55
        return self.relu(out)
56
57
class IncUNet (nn.Module): ### Inception Unet
58
    
59
    def __init__(self, in_shape):
60
        
61
        super(IncUNet, self).__init__()
62
        in_channels, height, width = in_shape
63
        
64
        self.e1 = nn.Sequential(
65
            nn.Conv1d(in_channels, 64, kernel_size=4, stride=2,padding=1),
66
            nn.BatchNorm1d(64),
67
            nn.LeakyReLU(0.2,),
68
            IncResBlock(64,64))
69
        
70
        self.e2 = nn.Sequential(
71
            nn.LeakyReLU(0.2,inplace=True),
72
            nn.Conv1d(64, 128, kernel_size=4, stride=2,padding=1),
73
            nn.BatchNorm1d(128),
74
            IncResBlock(128,128))
75
        
76
        self.e2add = nn.Sequential(
77
            nn.Conv1d(128, 128, kernel_size=3, stride=1,padding=1),
78
            nn.BatchNorm1d(128))
79
        
80
        self.e3 = nn.Sequential(
81
            nn.LeakyReLU(0.2,inplace=True),
82
            nn.Conv1d(128, 128, kernel_size=3, stride=1,padding=1),
83
            nn.BatchNorm1d(128),
84
            nn.LeakyReLU(0.2,),
85
            nn.Conv1d(128,256, kernel_size=4, stride=2,padding=1),
86
            nn.BatchNorm1d(256),
87
            IncResBlock(256,256))
88
        
89
        self.e4 = nn.Sequential(
90
            nn.LeakyReLU(0.2,),
91
            nn.Conv1d(256,256, kernel_size=4 , stride=1 , padding=1),
92
            nn.BatchNorm1d(256),
93
            nn.LeakyReLU(0.2,inplace=True),
94
            nn.Conv1d(256,512, kernel_size=4, stride=2,padding=2),
95
            nn.BatchNorm1d(512),
96
            IncResBlock(512,512))
97
        
98
        self.e4add = nn.Sequential(
99
            nn.LeakyReLU(0.2,),
100
            nn.Conv1d(512,512, kernel_size=3, stride=1,padding=1),
101
            nn.BatchNorm1d(512)) 
102
        
103
        self.e5 = nn.Sequential(
104
            nn.LeakyReLU(0.2,inplace=True),
105
            nn.Conv1d(512,512, kernel_size=3, stride=1,padding=1),
106
            nn.BatchNorm1d(512),
107
            nn.LeakyReLU(0.2,),
108
            nn.Conv1d(512,512, kernel_size=4, stride=2,padding=1),
109
            nn.BatchNorm1d(512),
110
            IncResBlock(512,512))
111
112
        self.e6 = nn.Sequential(
113
            nn.LeakyReLU(0.2,),
114
            nn.Conv1d(512,512, kernel_size=3, stride=1,padding=1),
115
            nn.BatchNorm1d(512),
116
            nn.LeakyReLU(0.2,inplace=True),
117
            nn.Conv1d(512,512, kernel_size=4, stride=2,padding=1),
118
            nn.BatchNorm1d(512), 
119
            IncResBlock(512,512))
120
        
121
        self.e6add = nn.Sequential(
122
            nn.Conv1d(512,512, kernel_size=3, stride=1,padding=1),
123
            nn.BatchNorm1d(512)) 
124
        
125
        self.e7 = nn.Sequential(
126
            nn.LeakyReLU(0.2,inplace=True),
127
            nn.Conv1d(512,512, kernel_size=3, stride=1,padding=1),
128
            nn.BatchNorm1d(512),
129
            nn.LeakyReLU(0.2,),
130
            nn.Conv1d(512,512, kernel_size=4, stride=2,padding=1),
131
            nn.BatchNorm1d(512),
132
            IncResBlock(512,512))
133
        
134
        self.e8 = nn.Sequential(
135
            nn.LeakyReLU(0.2,),
136
            nn.Conv1d(512,512, kernel_size=4, stride=1,padding=1),
137
            nn.BatchNorm1d(512),
138
            nn.LeakyReLU(0.2,inplace=True),
139
            nn.Conv1d(512,512, kernel_size=4, stride=2,padding=1),
140
            nn.BatchNorm1d(512))
141
        
142
        
143
        self.d1 = nn.Sequential(
144
            nn.LeakyReLU(0.2,),
145
            nn.ConvTranspose1d(512, 512, kernel_size=4, stride=2,padding=1),
146
            nn.BatchNorm1d(512),
147
            nn.LeakyReLU(0.2,),
148
            nn.ConvTranspose1d(512, 512, kernel_size=4, stride=1,padding =1),
149
            nn.BatchNorm1d(512),
150
            IncResBlock(512,512))
151
        
152
        self.d2 = nn.Sequential(
153
            nn.LeakyReLU(0.2,),
154
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2,padding=1),
155
            nn.BatchNorm1d(512),
156
            nn.LeakyReLU(0.2,),
157
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1,padding=1),
158
            nn.BatchNorm1d(512),
159
            IncResBlock(512,512))
160
        
161
        self.d3 = nn.Sequential(
162
            nn.LeakyReLU(0.2,),
163
            nn.ConvTranspose1d(1024, 512, kernel_size=3, stride=1,padding=1),
164
            nn.BatchNorm1d(512),
165
            nn.Dropout(p=0.5),
166
            IncResBlock(512,512))
167
        
168
        self.d4 = nn.Sequential(
169
            nn.LeakyReLU(0.2,),
170
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2,padding=1),
171
            nn.BatchNorm1d(512),
172
            nn.LeakyReLU(0.2,),
173
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1,padding=1),
174
            nn.BatchNorm1d(512),
175
            IncResBlock(512,512))
176
        
177
        self.d5 = nn.Sequential(
178
            nn.LeakyReLU(0.2,),
179
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2,padding=1),
180
            nn.BatchNorm1d(512),
181
            nn.LeakyReLU(0.2,),
182
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1,padding=1),
183
            nn.BatchNorm1d(512),
184
            IncResBlock(512,512))
185
        
186
        self.d6 = nn.Sequential(
187
            nn.LeakyReLU(0.2,),
188
            nn.ConvTranspose1d(1024, 512, kernel_size=3, stride=1,padding=1),
189
            nn.BatchNorm1d(512),
190
            IncResBlock(512,512))
191
        
192
        self.d7 = nn.Sequential(
193
            nn.LeakyReLU(0.2,),
194
            nn.ConvTranspose1d(1024, 256, kernel_size=4, stride=2,padding=1),
195
            nn.BatchNorm1d(256),
196
            nn.LeakyReLU(0.2,),
197
            nn.ConvTranspose1d(256, 256, kernel_size=3, stride=1,padding=1),
198
            nn.BatchNorm1d(256),
199
            IncResBlock(256,256))
200
        
201
        self.d8 = nn.Sequential(
202
            nn.LeakyReLU(0.2,),
203
            nn.ConvTranspose1d(512, 128, kernel_size=4, stride=2,padding=1),
204
            nn.BatchNorm1d(128),
205
            nn.LeakyReLU(0.2,),
206
            nn.ConvTranspose1d(128, 128, kernel_size=3, stride=1,padding=1),
207
            nn.BatchNorm1d(128))
208
        
209
        self.d9 = nn.Sequential(
210
            nn.LeakyReLU(0.2,),
211
            nn.ConvTranspose1d(256, 128, kernel_size=3, stride=1,padding=1),
212
            nn.BatchNorm1d(128))
213
        
214
        self.d10 = nn.Sequential(
215
            nn.LeakyReLU(0.2,),
216
            nn.ConvTranspose1d(256, 64, kernel_size=3, stride=1,padding=1),
217
            nn.BatchNorm1d(64))
218
        
219
        self.out_l = nn.Sequential(
220
            nn.LeakyReLU(0.2,),
221
            nn.ConvTranspose1d(256, in_channels, kernel_size=3, stride=1,padding=1))
222
    
223
    
224
    def forward(self, x):       
225
        
226
        ### Encoder
227
        en1 = self.e1(x)
228
        en2 = self.e2(en1)
229
        en2add = self.e2add(en2)
230
        en3 = self.e3(en2add)
231
        en4 = self.e4(en3)
232
        en4add = self.e4add(en4)
233
        en5 = self.e5(en4add)
234
        en6 = self.e6(en5)
235
        en6add = self.e6add(en6)
236
        en7 = self.e7(en6add)
237
        en8 = self.e8(en7)
238
        
239
240
        ### Decoder
241
        de1_ = self.d1(en8)
242
        de1 = torch.cat([en7,de1_],1)
243
        de2_ = self.d2(de1)
244
        de2 = torch.cat([en6add,de2_],1)
245
        de3_ = self.d3(de2)
246
        de3 = torch.cat([en6,de3_],1)
247
        de4_ = self.d4(de3)
248
        de4 = torch.cat([en5,de4_],1)
249
        de5_ = self.d5(de4)
250
        de5_ = nn.ConstantPad1d((0,1),0)(de5_)
251
        de5 = torch.cat([en4add,de5_],1)
252
        de6_ = self.d6(de5)
253
        de6 = torch.cat([en4,de6_],1)
254
        de7_ = self.d7(de6)
255
        de7_ = de7_[:,:,:-1]
256
        de7 = torch.cat([en3,de7_],1)
257
        de8 = self.d8(de7)
258
        de8_ = self.d8(de7)
259
        de8 = torch.cat([en2add,de8_],1)
260
        
261
        return self.out_l(de8)
262
        
263