Diff of /unet_3d.py [000000] .. [94d9b6]

Switch to side-by-side view

--- a
+++ b/unet_3d.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.optim as optim
+import torchvision
+from torchvision import datasets, models
+from torchvision import transforms as T
+from torch.utils.data import DataLoader, Dataset
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import time
+import pandas as pd
+from skimage import io, transform
+import matplotlib.image as mpimg
+from PIL import Image
+from sklearn.metrics import roc_auc_score
+import torch.nn.functional as F
+import scipy
+import random
+import pickle
+import scipy.io as sio
+import itertools
+from scipy.ndimage.interpolation import shift
+
+import warnings
+warnings.filterwarnings("ignore")
+plt.ion()   
+
+class Downsample_Unet3d(nn.Module):
+    def __init__(self, in_channels, out_channels, down_sample = (1,2,2)):
+        super(Downsample_Unet3d, self).__init__()
+        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding = 1)
+        self.bn1 = nn.BatchNorm3d(out_channels)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding = 1)
+        self.bn2 = nn.BatchNorm3d(out_channels)
+        self.conv3 = nn.Conv3d(out_channels, out_channels, down_sample, stride = down_sample)
+        self.bn3 = nn.BatchNorm3d(out_channels)
+
+    def forward(self,x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        y = F.relu(self.bn2(self.conv2(x)))
+#         x = F.max_pool3d(y, (1,2,2),stride = (1,2,2))
+        x = F.relu(self.bn3(self.conv3(y)))
+        
+        return x, y
+
+class Upsample_Unet3d(nn.Module):
+    def __init__(self,in_channels, out_channels, up_sample = (1,2,2)):
+        super(Upsample_Unet3d, self).__init__()
+        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, up_sample, stride = up_sample)
+        self.bnup = nn.BatchNorm3d(out_channels)
+        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding = 1)
+        self.bn1 = nn.BatchNorm3d(out_channels)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding = 1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        
+    def forward(self,x, y):
+        x = F.relu(self.bnup(self.upconv(x)))
+        x = torch.cat((x,y),dim = 1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        
+        return x
+
+class Unet_3D(nn.Module):
+    def __init__(self, in_layers, out_layers):
+        super(Unet_3D, self).__init__()
+        self.down1 = Downsample_Unet3d(in_layers,32)
+        self.down2 = Downsample_Unet3d(32,64)
+        self.down3 = Downsample_Unet3d(64,128)
+        self.down4 = Downsample_Unet3d(128,256)
+        self.conv1 = nn.Conv3d(256,512, 3, padding=1)
+        self.bn1 = nn.BatchNorm3d(512)
+        self.conv2 = nn.Conv3d(512,512,3, padding = 1)
+        self.bn2 = nn.BatchNorm3d(512)
+        self.up4 = Upsample_Unet3d(512,256)
+        self.up3 = Upsample_Unet3d(256,128)
+        self.up2 = Upsample_Unet3d(128,64)
+        self.up1 = Upsample_Unet3d(64,32)
+        self.outconv = nn.Conv3d(32,out_layers, 1)
+        
+    def forward(self,x):
+        x, y1 = self.down1(x)
+        x, y2 = self.down2(x)
+        x, y3 = self.down3(x)
+        x, y4 = self.down4(x)
+        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))))
+        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))))
+        x = self.up4(x, y4)
+        x = self.up3(x, y3)
+        x = self.up2(x, y2)
+        x = self.up1(x, y1)
+        x = self.outconv(x)
+        
+        return x
+