Diff of /Models_3D.py [000000] .. [3e03fc]

"""
This code was written by Dr. Jun Zhang. If you use this code, please follow the terms of the
Attribution-NonCommercial-ShareAlike 4.0 International license.
"""

import torch
import torch.nn as nn


def center_crop(layer, n_size):
    # Symmetrically crop the three spatial dimensions of a 5D tensor (N, C, D, H, W)
    # to n_size, so an encoder feature map can be concatenated with an upsampled
    # decoder feature map of smaller spatial size.
    cropidx = (layer.size(2) - n_size) // 2
    return layer[:, :, cropidx:(cropidx + n_size),
                 cropidx:(cropidx + n_size),
                 cropidx:(cropidx + n_size)]

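
# Illustrative only (not part of the original file): a minimal sketch of how center_crop
# aligns an encoder feature map with an upsampled decoder feature map before the
# channel-wise concatenation used in the forward() methods below. The tensor sizes and
# the helper name are arbitrary example values.
def _demo_center_crop():
    skip = torch.randn(1, 64, 17, 17, 17)    # hypothetical encoder feature map
    up = torch.randn(1, 64, 8, 8, 8)         # hypothetical upsampled decoder feature map
    cropped = center_crop(skip, up.size(2))  # cropped to (1, 64, 8, 8, 8)
    return torch.cat((up, cropped), 1)       # (1, 128, 8, 8, 8)
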
class ModelBreast(nn.Module):
    """U-Net-style 3D network with three pooling stages. Convolutions are unpadded
    (valid), so encoder feature maps are center-cropped before being concatenated with
    the upsampled decoder feature maps, and the output is smaller than the input."""

    def __init__(self, in_channel, n_classes):
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.start_channel = 32

        super(ModelBreast, self).__init__()
        # Encoder path
        self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False)
        self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False)
        self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False)
        self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False)
        self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False)
        self.ec5 = self.encoder(self.start_channel*4, self.start_channel*4, bias=False)
        self.ec6 = self.encoder(self.start_channel*4, self.start_channel*8, bias=False)
        self.ec7 = self.encoder(self.start_channel*8, self.start_channel*4, bias=False)

        self.pool = nn.MaxPool3d(2)

        # Decoder path (dc1, dc3 and dc5 receive concatenated skip connections)
        self.dc1 = self.encoder(self.start_channel*4+self.start_channel*4, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc3 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc4 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc5 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc6 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc7 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        # Upsampling (transposed convolution) blocks
        self.up1 = self.decoder(self.start_channel*4, self.start_channel*4)
        self.up2 = self.decoder(self.start_channel*2, self.start_channel*2)
        self.up3 = self.decoder(self.start_channel*2, self.start_channel*2)

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        # Conv3d (+ optional BatchNorm3d) + ReLU block
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer

    def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
        # Transposed convolution that doubles the spatial size
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer

    def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        # Final block with a sigmoid activation producing per-voxel probabilities
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.Sigmoid())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.Sigmoid())
        return layer

    def forward(self, x):
        # Encoder
        e0 = self.eninput(x)
        e0 = self.ec1(e0)

        e1 = self.pool(e0)
        e1 = self.ec2(e1)
        e1 = self.ec3(e1)

        e2 = self.pool(e1)
        e2 = self.ec4(e2)
        e2 = self.ec5(e2)

        e3 = self.pool(e2)
        e3 = self.ec6(e3)
        e3 = self.ec7(e3)

        # Decoder: upsample, center-crop the matching encoder features, concatenate, convolve
        d0 = torch.cat((self.up1(e3), center_crop(e2, e3.size(2)*2)), 1)
        d0 = self.dc1(d0)
        d0 = self.dc2(d0)

        d1 = torch.cat((self.up2(d0), center_crop(e1, d0.size(2)*2)), 1)
        d1 = self.dc3(d1)
        d1 = self.dc4(d1)

        d2 = torch.cat((self.up3(d1), center_crop(e0, d1.size(2)*2)), 1)
        d2 = self.dc5(d2)
        d2 = self.dc6(d2)
        d2 = self.dc7(d2)

        return d2

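
# Illustrative only (not part of the original file): because ModelBreast uses unpadded
# convolutions, its output volume is smaller than its input volume. This hypothetical
# helper simply reports the output shape for a chosen cubic input size.
def _probe_modelbreast_shapes(size=96, in_channel=1, n_classes=1):
    model = ModelBreast(in_channel, n_classes).eval()
    with torch.no_grad():
        out = model(torch.randn(1, in_channel, size, size, size))
    return out.shape
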
class ModelTumor(nn.Module):
    def __init__(self, in_channel, n_classes):
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.start_channel = 32

        super(ModelTumor, self).__init__()
        self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False)
        self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False)
        self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False)
        self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False)
        self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False)
        self.ec5 = self.encoder(self.start_channel*4, self.start_channel*2, bias=False)

        self.pool = nn.MaxPool3d(2)

        self.dc1 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc3 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc4 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc5 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        self.up1 = self.decoder(self.start_channel*2, self.start_channel*2)
        self.up2 = self.decoder(self.start_channel*2, self.start_channel*2)

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer

    def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer

    def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.Sigmoid())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.Sigmoid())
        return layer

    def forward(self, x):
        e0 = self.eninput(x)
        e0 = self.ec1(e0)

        e1 = self.pool(e0)
        e1 = self.ec2(e1)
        e1 = self.ec3(e1)

        e2 = self.pool(e1)
        e2 = self.ec4(e2)
        e2 = self.ec5(e2)

        d0 = torch.cat((self.up1(e2), center_crop(e1, e2.size(2)*2)), 1)
        d0 = self.dc1(d0)
        d0 = self.dc2(d0)

        d1 = torch.cat((self.up2(d0), center_crop(e0, d0.size(2)*2)), 1)
        d1 = self.dc3(d1)
        d1 = self.dc4(d1)
        d1 = self.dc5(d1)

        return d1

    

# Note: the *_train models below were trained with the same input and output size
# (96*96*96). Zero padding (padding=1 in the 3x3x3 convolutions) keeps feature maps the
# same size after each convolution, so skip connections are concatenated without cropping.
# A small shape check is sketched after ModelBreast_train below.

class ModelTumor_train(nn.Module):
    def __init__(self, in_channel, n_classes):
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.start_channel = 32

        super(ModelTumor_train, self).__init__()
        self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False)
        self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False)
        self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False)
        self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False)
        self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False)
        self.ec5 = self.encoder(self.start_channel*4, self.start_channel*2, bias=False)

        self.pool = nn.MaxPool3d(2)

        self.dc1 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc3 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc4 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc5 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        self.up1 = self.decoder(self.start_channel*2, self.start_channel*2)
        self.up2 = self.decoder(self.start_channel*2, self.start_channel*2)

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer

    def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer

    def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.Sigmoid())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.Sigmoid())
        return layer

    def forward(self, x):
        e0 = self.eninput(x)
        e0 = self.ec1(e0)

        e1 = self.pool(e0)
        e1 = self.ec2(e1)
        e1 = self.ec3(e1)

        e2 = self.pool(e1)
        e2 = self.ec4(e2)
        e2 = self.ec5(e2)

        d0 = torch.cat((self.up1(e2), e1), 1)
        d0 = self.dc1(d0)
        d0 = self.dc2(d0)

        d1 = torch.cat((self.up2(d0), e0), 1)
        d1 = self.dc3(d1)
        d1 = self.dc4(d1)
        d1 = self.dc5(d1)

        return d1

class ModelBreast_train(nn.Module):
    """Training version of ModelBreast: the 3x3x3 convolutions use padding=1, so the
    output keeps the spatial size of the input and no center cropping is needed."""

    def __init__(self, in_channel, n_classes):
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.start_channel = 32

        super(ModelBreast_train, self).__init__()
        self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False, batchnorm=True)
        self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False, batchnorm=True)
        self.ec2 = self.encoder(self.start_channel, self.start_channel*2, bias=False, batchnorm=True)
        self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False, batchnorm=True)
        self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, bias=False, batchnorm=True)
        self.ec5 = self.encoder(self.start_channel*4, self.start_channel*4, bias=False, batchnorm=True)
        self.ec6 = self.encoder(self.start_channel*4, self.start_channel*8, bias=False, batchnorm=True)
        self.ec7 = self.encoder(self.start_channel*8, self.start_channel*4, bias=False, batchnorm=True)

        self.pool = nn.MaxPool3d(2)

        self.dc1 = self.encoder(self.start_channel*4+self.start_channel*4, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc2 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc3 = self.encoder(self.start_channel*2+self.start_channel*2, self.start_channel*4, kernel_size=3, stride=1, bias=False)
        self.dc4 = self.encoder(self.start_channel*4, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc5 = self.encoder(self.start_channel*2+self.start_channel*1, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc6 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False)
        self.dc7 = self.outputs(self.start_channel*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        self.up1 = self.decoder(self.start_channel*4, self.start_channel*4)
        self.up2 = self.decoder(self.start_channel*2, self.start_channel*2)
        self.up3 = self.decoder(self.start_channel*2, self.start_channel*2)

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer

    def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer

    def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                bias=False, batchnorm=True):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.Sigmoid())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.Sigmoid())
        return layer

    def forward(self, x):
        e0 = self.eninput(x)
        e0 = self.ec1(e0)

        e1 = self.pool(e0)
        e1 = self.ec2(e1)
        e1 = self.ec3(e1)

        e2 = self.pool(e1)
        e2 = self.ec4(e2)
        e2 = self.ec5(e2)

        e3 = self.pool(e2)
        e3 = self.ec6(e3)
        e3 = self.ec7(e3)

        d0 = torch.cat((self.up1(e3), e2), 1)
        d0 = self.dc1(d0)
        d0 = self.dc2(d0)

        d1 = torch.cat((self.up2(d0), e1), 1)
        d1 = self.dc3(d1)
        d1 = self.dc4(d1)

        d2 = torch.cat((self.up3(d1), e0), 1)
        d2 = self.dc5(d2)
        d2 = self.dc6(d2)
        d2 = self.dc7(d2)

        return d2

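
# Illustrative only (not part of the original file): a minimal sketch checking that the
# *_train models preserve the 96*96*96 input size mentioned in the note above. The single
# input channel, single output class and the helper name are assumptions for demonstration.
def _check_train_output_sizes():
    x = torch.randn(1, 1, 96, 96, 96)
    tumor_out = ModelTumor_train(in_channel=1, n_classes=1)(x)
    breast_out = ModelBreast_train(in_channel=1, n_classes=1)(x)
    assert tumor_out.shape[2:] == x.shape[2:]   # 96^3 in, 96^3 out
    assert breast_out.shape[2:] == x.shape[2:]  # 96^3 in, 96^3 out
    return tumor_out.shape, breast_out.shape
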
def Dice_loss(input, target):
    # Soft Dice loss: 1 - 2*sum(a*b) / (sum(a^2) + sum(b^2) + smooth);
    # smooth avoids division by zero for empty masks.
    smooth = 0.00000001

    y_true_f = input.view(-1)
    y_pred_f = target.view(-1)
    intersection = torch.sum(torch.mul(y_true_f, y_pred_f))

    return 1 - ((2. * intersection) /
                (torch.mul(y_true_f, y_true_f).sum() + torch.mul(y_pred_f, y_pred_f).sum() + smooth))


def DICESEN_loss(input, target):
    # Combined Dice + sensitivity loss (2 - dice - sen). The sensitivity term normalises
    # the overlap by the first argument, which suggests the first argument is the
    # ground-truth mask and the second is the prediction.
    smooth = 0.00000001
    y_true_f = input.view(-1)
    y_pred_f = target.view(-1)
    intersection = torch.sum(torch.mul(y_true_f, y_pred_f))
    dice = (2. * intersection) / (torch.mul(y_true_f, y_true_f).sum() + torch.mul(y_pred_f, y_pred_f).sum() + smooth)
    sen = (1. * intersection) / (torch.mul(y_true_f, y_true_f).sum() + smooth)
    return 2 - dice - sen
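
# Illustrative only (not part of the original file): a minimal sketch of how the losses
# above might be used with one of the *_train models. The shapes and the argument order
# (mask first, prediction second) are assumptions based on the formulas above.
if __name__ == "__main__":
    net = ModelTumor_train(in_channel=1, n_classes=1)
    image = torch.randn(1, 1, 96, 96, 96)                  # hypothetical input patch
    mask = (torch.rand(1, 1, 96, 96, 96) > 0.5).float()    # hypothetical binary ground truth
    pred = net(image)                                      # sigmoid output in [0, 1], same 96^3 size
    print(Dice_loss(mask, pred).item(), DICESEN_loss(mask, pred).item())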