Diff of /unet_basic_dilated.py [000000] .. [94d9b6]

--- a
+++ b/unet_basic_dilated.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Downsample_block(nn.Module):
+    """Two 3x3 conv+BN+ReLU layers followed by 2x2 max pooling.
+
+    Returns the pooled tensor for the next encoder stage and the pre-pool
+    activation as the skip connection.
+    """
+    def __init__(self, in_channels, out_channels):
+        super(Downsample_block, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        y = F.relu(self.bn2(self.conv2(x)))
+        x = F.max_pool2d(y, 2, stride=2)
+
+        return x, y
+
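+# A minimal sketch of how Downsample_block is used (names and sizes below are
+# illustrative, not from the original file): the first return value feeds the
+# next encoder stage at half resolution, the second is the skip tensor.
+#
+#   block = Downsample_block(3, 64)
+#   pooled, skip = block(torch.randn(1, 3, 64, 64))
+#   assert pooled.shape == (1, 64, 32, 32)
+#   assert skip.shape == (1, 64, 64, 64)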
+
+class Down_sample_dilated_block(nn.Module):
+    """Downsampling block with three parallel dilated 3x3 convolutions.
+
+    The branches use dilation 1, 3, and 5 (with padding matched so spatial
+    size is preserved), are concatenated, fused by a 1x1 convolution, and
+    then downsampled by a stride-2 convolution instead of max pooling.
+    """
+    def __init__(self, in_channels, out_channels, dropout=False):
+        super(Down_sample_dilated_block, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2_1d = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.bn2_1d = nn.BatchNorm2d(out_channels)
+        self.conv2_2d = nn.Conv2d(out_channels, out_channels, 3, padding=3, dilation=3)
+        self.bn2_2d = nn.BatchNorm2d(out_channels)
+        self.conv2_3d = nn.Conv2d(out_channels, out_channels, 3, padding=5, dilation=5)
+        self.bn2_3d = nn.BatchNorm2d(out_channels)
+        self.conv3 = nn.Conv2d(out_channels * 3, out_channels, 1)
+        self.bn3 = nn.BatchNorm2d(out_channels)
+        self.conv_down = nn.Conv2d(out_channels, out_channels, 2, stride=2)
+        self.bn_down_conv = nn.BatchNorm2d(out_channels)
+        self.dp = dropout
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x1 = F.relu(self.bn2_1d(self.conv2_1d(x)))
+        x2 = F.relu(self.bn2_2d(self.conv2_2d(x)))
+        x3 = F.relu(self.bn2_3d(self.conv2_3d(x)))
+        x = torch.cat([x1, x2, x3], dim=1)
+        y = F.relu(self.bn3(self.conv3(x)))
+        x = F.relu(self.bn_down_conv(self.conv_down(y)))
+        if self.dp:
+            # Pass training=self.training so dropout is disabled in eval mode
+            # (F.dropout2d defaults to training=True).
+            x = F.dropout2d(x, training=self.training)
+
+        return x, y
+
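+# Receptive-field note (a general fact about dilated convolutions, not code
+# from the original file): a 3x3 kernel with dilation d spans 2*d + 1 pixels,
+# so the three branches cover 3x3, 7x7, and 11x11 windows at the same
+# parameter cost.  Illustrative shape check with arbitrary sizes:
+#
+#   block = Down_sample_dilated_block(3, 64, dropout=True)
+#   down, skip = block(torch.randn(1, 3, 64, 64))
+#   assert down.shape == (1, 64, 32, 32)
+#   assert skip.shape == (1, 64, 64, 64)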
+
+class Upsample_block(nn.Module):
+    """Upsampling block: 2x transposed convolution, skip concatenation,
+    then two 3x3 conv+BN+ReLU layers.
+
+    conv1 consumes in_channels after concatenation, so this block assumes
+    the skip tensor has out_channels channels and in_channels == 2 * out_channels.
+    """
+    def __init__(self, in_channels, out_channels, dropout=False):
+        super(Upsample_block, self).__init__()
+        self.transconv = nn.ConvTranspose2d(in_channels, out_channels, 4, padding=1, stride=2)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.dp = dropout
+
+    def forward(self, x, y):
+        x = self.transconv(x)
+        if self.dp:
+            # Dropout on the skip tensor only; disabled in eval mode.
+            y = F.dropout2d(y, training=self.training)
+        x = torch.cat((x, y), dim=1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+
+        return x
+
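+# With kernel 4, stride 2, padding 1 the transposed convolution exactly
+# doubles the spatial size: out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in.
+# Illustrative check (names and sizes are not from the original file):
+#
+#   up = Upsample_block(1024, 512)
+#   out = up(torch.randn(1, 1024, 8, 8), torch.randn(1, 512, 16, 16))
+#   assert out.shape == (1, 512, 16, 16)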
+class Upsample_block_small(nn.Module):
+    """Upsample_block variant for constant channel width.
+
+    Intended for in_channels == out_channels: the concatenation of the
+    upsampled tensor and the skip then has 2 * in_channels channels,
+    which is what conv1 expects.
+    """
+    def __init__(self, in_channels, out_channels, dropout=False):
+        super(Upsample_block_small, self).__init__()
+        self.transconv = nn.ConvTranspose2d(in_channels, out_channels, 4, padding=1, stride=2)
+        self.conv1 = nn.Conv2d(2 * in_channels, out_channels, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.dp = dropout
+
+    def forward(self, x, y):
+        x = self.transconv(x)
+        if self.dp:
+            y = F.dropout2d(y, training=self.training)
+        x = torch.cat((x, y), dim=1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+
+        return x
+
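+# Illustrative shape check for the "_small" variant (arbitrary sizes):
+#
+#   up = Upsample_block_small(64, 64)
+#   out = up(torch.randn(1, 64, 8, 8), torch.randn(1, 64, 16, 16))
+#   assert out.shape == (1, 64, 16, 16)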
+class Unet_dilated_gen(nn.Module):
+    """U-Net with four down/up stages and channel widths doubling to 1024.
+
+    With dilated=True the encoder uses the multi-dilation blocks; otherwise
+    it falls back to the plain max-pooling blocks.
+    """
+    def __init__(self, in_layers, out_layers, dilated=True, dropout=False):
+        super(Unet_dilated_gen, self).__init__()
+        if dilated:
+            self.down1 = Down_sample_dilated_block(in_layers, 64)
+            self.down2 = Down_sample_dilated_block(64, 128)
+            self.down3 = Down_sample_dilated_block(128, 256, dropout=dropout)
+            self.down4 = Down_sample_dilated_block(256, 512, dropout=dropout)
+        else:
+            self.down1 = Downsample_block(in_layers, 64)
+            self.down2 = Downsample_block(64, 128)
+            self.down3 = Downsample_block(128, 256)
+            self.down4 = Downsample_block(256, 512)
+        self.conv1 = nn.Conv2d(512, 1024, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(1024)
+        self.conv2 = nn.Conv2d(1024, 1024, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(1024)
+        self.up4 = Upsample_block(1024, 512)
+        self.up3 = Upsample_block(512, 256)
+        self.up2 = Upsample_block(256, 128, dropout=dropout)
+        self.up1 = Upsample_block(128, 64, dropout=dropout)
+        self.outconv = nn.Conv2d(64, out_layers, 1)
+
+    def forward(self, x):
+        x, y1 = self.down1(x)
+        x, y2 = self.down2(x)
+        x, y3 = self.down3(x)
+        x, y4 = self.down4(x)
+        # Bottleneck dropout; training=self.training keeps eval deterministic.
+        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))), training=self.training)
+        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))), training=self.training)
+        x = self.up4(x, y4)
+        x = self.up3(x, y3)
+        x = self.up2(x, y2)
+        x = self.up1(x, y1)
+        x = self.outconv(x)
+
+        return x
+
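+# Hedged usage sketch for Unet_dilated_gen (the 3-channel input, 2-class
+# output, and 128x128 size are illustrative; four stride-2 stages require
+# the input side to be divisible by 2^4 = 16):
+#
+#   net = Unet_dilated_gen(3, 2, dilated=True, dropout=True)
+#   logits = net(torch.randn(1, 3, 128, 128))
+#   assert logits.shape == (1, 2, 128, 128)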
+
+class Unet_dilated_small_deep(nn.Module):
+    """Deeper U-Net variant: five down/up stages at a constant int_var width.
+
+    Uses Upsample_block_small in the decoder because concatenating the
+    upsampled tensor with the skip tensor yields 2 * int_var channels.
+    """
+    def __init__(self, in_layers, out_layers, dilated=True, int_var=32):
+        super(Unet_dilated_small_deep, self).__init__()
+        if dilated:
+            self.down1 = Down_sample_dilated_block(in_layers, int_var)
+            self.down2 = Down_sample_dilated_block(int_var, int_var)
+            self.down3 = Down_sample_dilated_block(int_var, int_var)
+            self.down4 = Down_sample_dilated_block(int_var, int_var)
+            self.down5 = Down_sample_dilated_block(int_var, int_var)
+        else:
+            self.down1 = Downsample_block(in_layers, int_var)
+            self.down2 = Downsample_block(int_var, int_var)
+            self.down3 = Downsample_block(int_var, int_var)
+            self.down4 = Downsample_block(int_var, int_var)
+            self.down5 = Downsample_block(int_var, int_var)
+        self.conv1 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(int_var)
+        self.conv2 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(int_var)
+        self.up5 = Upsample_block_small(int_var, int_var)
+        self.up4 = Upsample_block_small(int_var, int_var)
+        self.up3 = Upsample_block_small(int_var, int_var)
+        self.up2 = Upsample_block_small(int_var, int_var)
+        self.up1 = Upsample_block_small(int_var, int_var)
+        self.outconv = nn.Conv2d(int_var, out_layers, 1)
+
+    def forward(self, x):
+        x, y1 = self.down1(x)
+        x, y2 = self.down2(x)
+        x, y3 = self.down3(x)
+        x, y4 = self.down4(x)
+        x, y5 = self.down5(x)
+        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))), training=self.training)
+        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))), training=self.training)
+        x = self.up5(x, y5)
+        x = self.up4(x, y4)
+        x = self.up3(x, y3)
+        x = self.up2(x, y2)
+        x = self.up1(x, y1)
+        x = self.outconv(x)
+
+        return x
+
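+# Five stride-2 stages mean the input side must be divisible by 2^5 = 32
+# (e.g. 160x160 gives a 5x5 bottleneck).  Illustrative sketch:
+#
+#   net = Unet_dilated_small_deep(3, 2, int_var=32)
+#   logits = net(torch.randn(1, 3, 160, 160))
+#   assert logits.shape == (1, 2, 160, 160)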
+class Unet_dilated_small(nn.Module):
+    """U-Net variant with four down/up stages at a constant int_var width."""
+    def __init__(self, in_layers, out_layers, int_var=64, dilated=True, dropout=False):
+        super(Unet_dilated_small, self).__init__()
+        if dilated:
+            self.down1 = Down_sample_dilated_block(in_layers, int_var)
+            self.down2 = Down_sample_dilated_block(int_var, int_var)
+            self.down3 = Down_sample_dilated_block(int_var, int_var, dropout=dropout)
+            self.down4 = Down_sample_dilated_block(int_var, int_var, dropout=dropout)
+        else:
+            self.down1 = Downsample_block(in_layers, int_var)
+            self.down2 = Downsample_block(int_var, int_var)
+            self.down3 = Downsample_block(int_var, int_var)
+            self.down4 = Downsample_block(int_var, int_var)
+        self.conv1 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(int_var)
+        self.conv2 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(int_var)
+        self.up4 = Upsample_block_small(int_var, int_var)
+        self.up3 = Upsample_block_small(int_var, int_var)
+        self.up2 = Upsample_block_small(int_var, int_var, dropout=dropout)
+        self.up1 = Upsample_block_small(int_var, int_var, dropout=dropout)
+        self.outconv = nn.Conv2d(int_var, out_layers, 1)
+
+    def forward(self, x):
+        x, y1 = self.down1(x)
+        x, y2 = self.down2(x)
+        x, y3 = self.down3(x)
+        x, y4 = self.down4(x)
+        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))), training=self.training)
+        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))), training=self.training)
+        x = self.up4(x, y4)
+        x = self.up3(x, y3)
+        x = self.up2(x, y2)
+        x = self.up1(x, y1)
+        x = self.outconv(x)
+
+        return x
+
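+# Unet_dilated_small keeps a constant width of int_var channels throughout,
+# trading capacity for memory.  Illustrative sketch:
+#
+#   net = Unet_dilated_small(3, 2, int_var=64)
+#   logits = net(torch.randn(1, 3, 128, 128))
+#   assert logits.shape == (1, 2, 128, 128)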
+class Unet_dilated_small_multi(nn.Module):
+    """Shared encoder with one independent small decoder per output class.
+
+    Each decoder ends in a 1x1 convolution producing a single channel; the
+    per-class maps are concatenated along the channel dimension.
+    """
+    def __init__(self, in_layers, out_layers, int_var=64, dilated=True, dropout=False):
+        super(Unet_dilated_small_multi, self).__init__()
+        if dilated:
+            self.down1 = Down_sample_dilated_block(in_layers, int_var)
+            self.down2 = Down_sample_dilated_block(int_var, int_var)
+            self.down3 = Down_sample_dilated_block(int_var, int_var, dropout=dropout)
+            self.down4 = Down_sample_dilated_block(int_var, int_var, dropout=dropout)
+        else:
+            self.down1 = Downsample_block(in_layers, int_var)
+            self.down2 = Downsample_block(int_var, int_var)
+            self.down3 = Downsample_block(int_var, int_var)
+            self.down4 = Downsample_block(int_var, int_var)
+        self.conv1 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(int_var)
+        self.conv2 = nn.Conv2d(int_var, int_var, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(int_var)
+        self.out_layers = out_layers
+        # One decoder (four upsampling blocks plus a 1x1 head) per class.
+        self.up_layer = nn.ModuleList()
+        for i in range(self.out_layers):
+            self.up_layer.append(nn.ModuleList([
+                Upsample_block_small(int_var, int_var),
+                Upsample_block_small(int_var, int_var),
+                Upsample_block_small(int_var, int_var, dropout=dropout),
+                Upsample_block_small(int_var, int_var, dropout=dropout),
+                nn.Conv2d(int_var, 1, 1),
+            ]))
+
+    def forward(self, x):
+        x, y1 = self.down1(x)
+        x, y2 = self.down2(x)
+        x, y3 = self.down3(x)
+        x, y4 = self.down4(x)
+        x = F.dropout2d(F.relu(self.bn1(self.conv1(x))), training=self.training)
+        x = F.dropout2d(F.relu(self.bn2(self.conv2(x))), training=self.training)
+        xs = []
+        for decoder in self.up_layer:
+            temp_x = decoder[0](x, y4)
+            temp_x = decoder[1](temp_x, y3)
+            temp_x = decoder[2](temp_x, y2)
+            temp_x = decoder[3](temp_x, y1)
+            xs.append(decoder[4](temp_x))
+        x = torch.cat(xs, dim=1)
+
+        return x
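+
+if __name__ == "__main__":
+    # Minimal smoke test (added for illustration, not part of the original
+    # file): builds each model variant and prints output shapes on random
+    # input.  The 3-channel input, 2-class output, and 128x128 size are
+    # arbitrary but divisible by every model's total downsampling factor.
+    x = torch.randn(1, 3, 128, 128)
+    for net in (Unet_dilated_gen(3, 2),
+                Unet_dilated_small_deep(3, 2),
+                Unet_dilated_small(3, 2),
+                Unet_dilated_small_multi(3, 2)):
+        net.eval()
+        with torch.no_grad():
+            out = net(x)
+        print(type(net).__name__, tuple(out.shape))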
\ No newline at end of file