Diff of /vnet.py [000000] .. [94d9b6]

Switch to unified view

a b/vnet.py
1
import torch
2
import torch.nn as nn
3
from torch.autograd import Variable
4
import torch.optim as optim
5
import torchvision
6
from torchvision import datasets, models
7
from torchvision import transforms as T
8
from torch.utils.data import DataLoader, Dataset
9
import numpy as np
10
import matplotlib.pyplot as plt
11
import os
12
import time
13
import pandas as pd
14
from skimage import io, transform
15
import matplotlib.image as mpimg
16
from PIL import Image
17
from sklearn.metrics import roc_auc_score
18
import torch.nn.functional as F
19
import scipy
20
import random
21
import pickle
22
import scipy.io as sio
23
import itertools
24
from scipy.ndimage.interpolation import shift
25
26
import warnings
27
warnings.filterwarnings("ignore")
28
plt.ion()
29
30
31
class Input_Vnet(nn.Module):
32
    def __init__(self, in_channels, out_channels):
33
        super(Input_Vnet, self).__init__()
34
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding=1)
35
        self.bn1 = nn.BatchNorm3d(out_channels)
36
        self.in_channels = in_channels
37
    
38
    def forward(self,x):
39
        ones = Variable(torch.ones(1,self.in_channels,1,256,256)).cuda()
40
        x = torch.cat([x,ones], dim = 2)
41
        return F.relu(self.bn1(self.conv1(x)))
42
43
class Downsample_Vnet(nn.Module):
44
    def __init__(self, in_channels, out_channels, down_sample = (2,2,2), num_layers = 1, drop_out = False):
45
        super(Downsample_Vnet,self).__init__()
46
        self.layer = nn.ModuleList()
47
        for i in range(num_layers):
48
            self.layer.append(nn.Conv3d(out_channels, out_channels, 3, padding =1))
49
            self.layer.append(nn.BatchNorm3d(out_channels))
50
            self.layer.append(nn.ReLU())
51
        self.down_conv = nn.Conv3d(in_channels, out_channels, down_sample, stride = down_sample)
52
        self.down_bn = nn.BatchNorm3d(out_channels)
53
        self.drop_out = drop_out
54
    
55
    def forward(self, x):
56
        a = F.relu(self.down_bn(self.down_conv(x)))
57
        if self.drop_out:
58
            x = F.dropout3d(a)
59
        else:
60
            x = a
61
        for i in self.layer:
62
            x = i(x)
63
        x = F.relu(torch.add(x,a))
64
        
65
        return x
66
67
class Upsample_Vnet(nn.Module):
68
    def __init__(self, in_channels, out_channels, up_sample = (2,2,2), num_layers = 1, drop_out = False):
69
        super(Upsample_Vnet, self).__init__()
70
        self.layer = nn.ModuleList()
71
        for i in range(num_layers):
72
            self.layer.append(nn.Conv3d(out_channels, out_channels, 3, padding =1))
73
            self.layer.append(nn.BatchNorm3d(out_channels))
74
            self.layer.append(nn.ReLU())
75
        self.up_conv = nn.ConvTranspose3d(in_channels, out_channels//2, up_sample, stride = up_sample)
76
        self.up_bn = nn.BatchNorm3d(out_channels//2)
77
        self.drop_out = drop_out
78
        
79
    def forward(self, x, y):
80
        if self.drop_out:
81
            x = F.dropout3d(x)
82
        x = F.relu(self.up_bn(self.up_conv(x)))
83
        y = F.dropout3d(y)
84
        x = torch.cat([x,y], dim = 1)
85
        a = x
86
        for i in self.layer:
87
            x = i(x)
88
        x = F.relu(torch.add(x,a))
89
        
90
        return x
91
92
93
class Output_Vnet(nn.Module):
94
    def __init__(self, in_channels, out_channels):
95
        super(Output_Vnet, self).__init__()
96
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding = 1)
97
        self.bn1 = nn.BatchNorm3d(out_channels)
98
        self.conv2 = nn.Conv3d(out_channels, out_channels, 1)
99
        
100
    def forward(self, x):
101
        x = F.relu(self.bn1(self.conv1(x)))
102
        x = self.conv2(x)
103
        return x[:,:,:15,:,:]
104
105
106
class Vnet(nn.Module):
107
    def __init__(self, in_channels, out_channels):
108
        super(Vnet, self).__init__()
109
        self.input_layer = Input_Vnet(in_channels, 16)
110
        self.down1 = Downsample_Vnet(16, 32, down_sample=(1,2,2), num_layers=2)
111
        self.down2 = Downsample_Vnet(32, 64, num_layers= 3)
112
        self.down3 = Downsample_Vnet(64, 128, down_sample = (1,2,2),num_layers = 3, drop_out=True)
113
        self.down4 = Downsample_Vnet(128, 256, num_layers= 3, drop_out=True)
114
        self.up4 = Upsample_Vnet(256, 256, num_layers = 3, drop_out=True)
115
        self.up3 = Upsample_Vnet(256,128, up_sample= (1,2,2), num_layers = 3, drop_out=True)
116
        self.up2 = Upsample_Vnet(128,64, num_layers = 2)
117
        self.up1 = Upsample_Vnet(64, 32, up_sample=(1,2,2), num_layers=1)
118
        self.output = Output_Vnet(32, 4)
119
    
120
    def forward(self, x):
121
        x_in = self.input_layer(x)
122
        x_d1 = self.down1(x_in)
123
        x_d2 = self.down2(x_d1)
124
        x_d3 = self.down3(x_d2)
125
        x_d4 = self.down4(x_d3)
126
        x = self.up4(x_d4, x_d3)
127
        x = self.up3(x,x_d2)
128
        x = self.up2(x, x_d1)
129
        x = self.up1(x, x_in)
130
        x = self.output(x)
131
        
132
        return x