Diff of /unet_3d.py [000000] .. [94d9b6]

Switch to unified view

a b/unet_3d.py
1
import torch
2
import torch.nn as nn
3
from torch.autograd import Variable
4
import torch.optim as optim
5
import torchvision
6
from torchvision import datasets, models
7
from torchvision import transforms as T
8
from torch.utils.data import DataLoader, Dataset
9
import numpy as np
10
import matplotlib.pyplot as plt
11
import os
12
import time
13
import pandas as pd
14
from skimage import io, transform
15
import matplotlib.image as mpimg
16
from PIL import Image
17
from sklearn.metrics import roc_auc_score
18
import torch.nn.functional as F
19
import scipy
20
import random
21
import pickle
22
import scipy.io as sio
23
import itertools
24
from scipy.ndimage.interpolation import shift
25
26
import warnings
27
warnings.filterwarnings("ignore")
28
plt.ion()   
29
30
class Downsample_Unet3d(nn.Module):
31
    def __init__(self, in_channels, out_channels, down_sample = (1,2,2)):
32
        super(Downsample_Unet3d, self).__init__()
33
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding = 1)
34
        self.bn1 = nn.BatchNorm3d(out_channels)
35
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding = 1)
36
        self.bn2 = nn.BatchNorm3d(out_channels)
37
        self.conv3 = nn.Conv3d(out_channels, out_channels, down_sample, stride = down_sample)
38
        self.bn3 = nn.BatchNorm3d(out_channels)
39
40
    def forward(self,x):
41
        x = F.relu(self.bn1(self.conv1(x)))
42
        y = F.relu(self.bn2(self.conv2(x)))
43
#         x = F.max_pool3d(y, (1,2,2),stride = (1,2,2))
44
        x = F.relu(self.bn3(self.conv3(y)))
45
        
46
        return x, y
47
48
class Upsample_Unet3d(nn.Module):
49
    def __init__(self,in_channels, out_channels, up_sample = (1,2,2)):
50
        super(Upsample_Unet3d, self).__init__()
51
        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, up_sample, stride = up_sample)
52
        self.bnup = nn.BatchNorm3d(out_channels)
53
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding = 1)
54
        self.bn1 = nn.BatchNorm3d(out_channels)
55
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding = 1)
56
        self.bn2 = nn.BatchNorm2d(out_channels)
57
        
58
    def forward(self,x, y):
59
        x = F.relu(self.bnup(self.upconv(x)))
60
        x = torch.cat((x,y),dim = 1)
61
        x = F.relu(self.bn1(self.conv1(x)))
62
        x = F.relu(self.bn2(self.conv2(x)))
63
        
64
        return x
65
66
class Unet_3D(nn.Module):
67
    def __init__(self, in_layers, out_layers):
68
        super(Unet_3D, self).__init__()
69
        self.down1 = Downsample_Unet3d(in_layers,32)
70
        self.down2 = Downsample_Unet3d(32,64)
71
        self.down3 = Downsample_Unet3d(64,128)
72
        self.down4 = Downsample_Unet3d(128,256)
73
        self.conv1 = nn.Conv3d(256,512, 3, padding=1)
74
        self.bn1 = nn.BatchNorm3d(512)
75
        self.conv2 = nn.Conv3d(512,512,3, padding = 1)
76
        self.bn2 = nn.BatchNorm3d(512)
77
        self.up4 = Upsample_Unet3d(512,256)
78
        self.up3 = Upsample_Unet3d(256,128)
79
        self.up2 = Upsample_Unet3d(128,64)
80
        self.up1 = Upsample_Unet3d(64,32)
81
        self.outconv = nn.Conv3d(32,out_layers, 1)
82
        
83
    def forward(self,x):
84
        x, y1 = self.down1(x)
85
        x, y2 = self.down2(x)
86
        x, y3 = self.down3(x)
87
        x, y4 = self.down4(x)
88
        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))))
89
        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))))
90
        x = self.up4(x, y4)
91
        x = self.up3(x, y3)
92
        x = self.up2(x, y2)
93
        x = self.up1(x, y1)
94
        x = self.outconv(x)
95
        
96
        return x
97