Diff of /ecgtoHR/HRnet.py [000000] .. [c0487b]

Switch to unified view

a b/ecgtoHR/HRnet.py
1
import torch.nn as nn
2
import torch
3
from torch.autograd import Variable
4
import numpy as np
5
6
class IncResBlock(nn.Module):
    """Inception-style residual block for 1-D signals.

    Four parallel convolution branches (kernel sizes convsize, convsize+2,
    convsize+4, convsize+6, each emitting planes//4 channels) are
    concatenated and added to a 1x1-projected shortcut, then passed
    through ReLU.  ``planes`` must be divisible by 4 so the concatenated
    width matches the shortcut.
    """

    def __init__(self, inplanes, planes, convstr=1, convsize = 15, convpadding = 7):
        super(IncResBlock, self).__init__()
        quarter = planes // 4
        # 1x1 projection so the shortcut has the same channel count as
        # the concatenated branch outputs.
        self.Inputconv1x1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride = 1, bias=False)
        # Branch 1: a single conv at the base kernel size (keeps its bias).
        self.conv1_1 = nn.Sequential(
            nn.Conv1d(in_channels=inplanes, out_channels=quarter,
                      kernel_size=convsize, stride=convstr, padding=convpadding),
            nn.BatchNorm1d(quarter))
        # Branches 2-4: 1x1 bottleneck, then progressively wider kernels.
        # NOTE(review): both convs in these branches use stride=convstr, so
        # for convstr > 1 the overall branch stride would differ from branch
        # 1 — every visible caller uses the default convstr=1.
        self.conv1_2 = self._bottleneck_branch(inplanes, quarter, convsize + 2, convstr, convpadding + 1)
        self.conv1_3 = self._bottleneck_branch(inplanes, quarter, convsize + 4, convstr, convpadding + 2)
        self.conv1_4 = self._bottleneck_branch(inplanes, quarter, convsize + 6, convstr, convpadding + 3)
        self.relu = nn.ReLU()

    @staticmethod
    def _bottleneck_branch(inplanes, quarter, ksize, stride, pad):
        """Build one 1x1-bottleneck + wide-kernel conv branch."""
        return nn.Sequential(
            nn.Conv1d(inplanes, quarter, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm1d(quarter),
            nn.LeakyReLU(0.2,),
            nn.Conv1d(in_channels=quarter, out_channels=quarter,
                      kernel_size=ksize, stride=stride, padding=pad),
            nn.BatchNorm1d(quarter))

    def forward(self, x):
        """Run the four branches, concatenate on channels, add the
        shortcut, and apply ReLU."""
        shortcut = self.Inputconv1x1(x)
        branches = [self.conv1_1(x), self.conv1_2(x), self.conv1_3(x), self.conv1_4(x)]
        out = torch.cat(branches, 1) + shortcut
        return self.relu(out)
46
47
class IncUNet (nn.Module):
    """1-D U-Net with inception-residual blocks (ECG -> heart-rate signal).

    The encoder (e1..e8) repeatedly downsamples the temporal axis with
    stride-2 convolutions; the decoder (d1..d10, out_l) upsamples with
    transposed convolutions, concatenating the matching encoder feature
    map (skip connection) on the channel axis before each stage.  The
    output has the same (batch, channel, length) shape as the input.

    Fix: the original ``forward`` called ``self.d8(de7)`` twice in a row,
    immediately discarding the first result — a full redundant decoder
    stage per forward pass.  The duplicate call is removed; outputs are
    unchanged.
    """

    def __init__(self, in_shape):
        """
        Args:
            in_shape: (in_channels, height, width) tuple.  Only
                ``in_channels`` is used by the layers; the other two
                entries are unpacked but ignored.
        """
        super(IncUNet, self).__init__()
        in_channels, height, width = in_shape
        # ---------------- encoder ----------------
        self.e1 = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2,),
            IncResBlock(64, 64))
        self.e2 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(128),
            IncResBlock(128, 128))
        # Stride-1 refinement whose output feeds e3 and is also kept as a skip.
        self.e2add = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128))
        self.e3 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2,),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(256),
            IncResBlock(256, 256))
        # kernel_size=4, stride=1, padding=1 shrinks the length by 1; the
        # decoder's matching stage (d7) grows it back by 1.
        self.e4 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv1d(256, 256, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.e4add = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512))
        self.e5 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.e6 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.e6add = nn.Sequential(
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512))
        self.e7 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        # Bottleneck: no IncResBlock, ends on BatchNorm.
        self.e8 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv1d(512, 512, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512))

        # ---------------- decoder ----------------
        # Stages fed by a skip concatenation take 1024 (=512+512) input channels.
        self.d1 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 512, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.d2 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        # Dropout is active only in train() mode.
        self.d3 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5),
            IncResBlock(512, 512))
        self.d4 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.d5 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.d6 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            IncResBlock(512, 512))
        self.d7 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(1024, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(256, 256, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm1d(256),
            IncResBlock(256, 256))
        self.d8 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(512, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128))
        self.d9 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128))
        self.d10 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(256, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(64))
        # Final upsample back to the input channel count and length.
        self.out_l = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.ConvTranspose1d(128, in_channels, kernel_size=4, stride=2, padding=1))

    def forward(self, x):
        """Encode ``x``, then decode with skip connections.

        Args:
            x: tensor of shape (batch, in_channels, length); the length
                must be compatible with the fixed stride/kernel schedule
                so every decoder stage matches its encoder skip.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Encoder pass; intermediate activations are kept for the skips.
        en1 = self.e1(x)
        en2 = self.e2(en1)
        en2add = self.e2add(en2)
        en3 = self.e3(en2add)
        en4 = self.e4(en3)
        en4add = self.e4add(en4)
        en5 = self.e5(en4add)
        en6 = self.e6(en5)
        en6add = self.e6add(en6)
        en7 = self.e7(en6add)
        en8 = self.e8(en7)

        # Decoder pass: each stage's output is concatenated with the
        # matching encoder activation on the channel axis (dim 1).
        de1_ = self.d1(en8)
        de1 = torch.cat([en7, de1_], 1)
        de2_ = self.d2(de1)
        de2 = torch.cat([en6add, de2_], 1)
        de3_ = self.d3(de2)
        de3 = torch.cat([en6, de3_], 1)
        de4_ = self.d4(de3)
        de4 = torch.cat([en5, de4_], 1)
        de5_ = self.d5(de4)
        de5 = torch.cat([en4add, de5_], 1)
        de6_ = self.d6(de5)
        de6 = torch.cat([en4, de6_], 1)
        de7_ = self.d7(de6)
        de7 = torch.cat([en3, de7_], 1)
        # Bug fix: d8 was invoked twice here; the first result was
        # discarded immediately.  One call suffices.
        de8_ = self.d8(de7)
        de8 = torch.cat([en2add, de8_], 1)
        de9_ = self.d9(de8)
        de9 = torch.cat([en2, de9_], 1)
        de10_ = self.d10(de9)
        de10 = torch.cat([en1, de10_], 1)
        out = self.out_l(de10)

        return out