[4fa73e]: /pytorch/graphs/models/discriminator.py


import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict

from graphs.models.custom_functions.weight_norm import WN_Conv3d, WN_ConvTranspose3d


class Discriminator(nn.Module):
    """3D U-Net used as the discriminator."""

    def __init__(self, config):
        super(Discriminator, self).__init__()
        self.config = config
        self.input_channel = self.config.num_modalities
        self.num_classes = self.config.num_classes

        kernel_size = (3, 3, 3)
        kernel_size_deconv = (2, 2, 2)
        stride_deconv = (2, 2, 2)
        out_channels = 32

        self.lrelu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout3d(p=0.2, inplace=False)
        self.final_activation = nn.Softmax(dim=1)

        # Encoder path: weight-normalized 3D convolutions with average pooling.
        self.pool = nn.AvgPool3d(2)
        self.encoder0 = WN_Conv3d(self.input_channel, out_channels, kernel_size)
        self.encoder1 = WN_Conv3d(out_channels, out_channels, kernel_size)
        self.encoder2 = WN_Conv3d(out_channels, out_channels * 2, kernel_size)
        self.encoder3 = WN_Conv3d(out_channels * 2, out_channels * 2, kernel_size)
        self.encoder4 = WN_Conv3d(out_channels * 2, out_channels * (2 ** 2), kernel_size)
        self.encoder5 = WN_Conv3d(out_channels * (2 ** 2), out_channels * (2 ** 2), kernel_size)
        self.encoder6 = WN_Conv3d(out_channels * (2 ** 2), out_channels * (2 ** 3), kernel_size)
        self.encoder7 = WN_Conv3d(out_channels * (2 ** 3), out_channels * (2 ** 3), kernel_size)

        # Decoder path: weight-normalized transposed convolutions with skip connections.
        self.decoder1 = WN_ConvTranspose3d(out_channels * (2 ** 3), out_channels * (2 ** 3),
                                           kernel_size_deconv, stride_deconv)
        # encoder5 + decoder1
        self.encoder8 = WN_Conv3d(out_channels * (2 ** 2) + out_channels * (2 ** 3),
                                  out_channels * (2 ** 2), kernel_size)
        self.encoder9 = WN_Conv3d(out_channels * (2 ** 2), out_channels * (2 ** 2), kernel_size)
        self.decoder2 = WN_ConvTranspose3d(out_channels * (2 ** 2), out_channels * (2 ** 2),
                                           kernel_size_deconv, stride_deconv)
        # encoder3 + decoder2
        self.encoder10 = WN_Conv3d(out_channels * 2 + out_channels * (2 ** 2),
                                   out_channels * 2, kernel_size)
        self.encoder11 = WN_Conv3d(out_channels * 2, out_channels * 2, kernel_size)
        self.decoder3 = WN_ConvTranspose3d(out_channels * 2, out_channels * 2,
                                           kernel_size_deconv, stride_deconv)
        # encoder1 + decoder3
        self.encoder12 = WN_Conv3d(out_channels + out_channels * 2, out_channels, kernel_size)
        self.encoder13 = WN_Conv3d(out_channels, out_channels, kernel_size)
        self.final_conv = WN_Conv3d(out_channels, self.num_classes, kernel_size)

    def forward(self, input, get_feature=False, use_dropout=False):
        # Encoder
        conv0 = self.lrelu(self.encoder0(input))
        conv1 = self.lrelu(self.encoder1(conv0))
        pool1 = self.pool(conv1)

        conv2 = self.lrelu(self.encoder2(pool1))
        conv3 = self.lrelu(self.encoder3(conv2))
        pool3 = self.pool(conv3)

        conv4 = self.lrelu(self.encoder4(pool3))
        conv5 = self.lrelu(self.encoder5(conv4))
        pool5 = self.pool(conv5)

        conv6 = self.lrelu(self.encoder6(pool5))
        conv7 = self.lrelu(self.encoder7(conv6))
        if use_dropout:
            conv7 = self.dropout(conv7)

        # Decoder with skip connections
        deconv1 = self.decoder1(conv7)
        skip_connection1 = torch.cat((conv5, deconv1), 1)
        conv8 = self.lrelu(self.encoder8(skip_connection1))
        conv9 = self.lrelu(self.encoder9(conv8))

        deconv2 = self.decoder2(conv9)
        skip_connection2 = torch.cat((conv3, deconv2), 1)
        conv10 = self.lrelu(self.encoder10(skip_connection2))
        conv11 = self.lrelu(self.encoder11(conv10))

        deconv3 = self.decoder3(conv11)
        skip_connection3 = torch.cat((conv1, deconv3), 1)
        conv12 = self.lrelu(self.encoder12(skip_connection3))
        conv13 = self.lrelu(self.encoder13(conv12))
        if use_dropout:
            conv13 = self.dropout(conv13)

        final_output = self.final_conv(conv13)
        if not get_feature:
            return final_output, self.final_activation(final_output)
        else:
            return final_output, self.final_activation(final_output), conv6


"""
netD testing
"""
def main():
    config = json.load(open('../../configs/gan_exp_0.json'))
    config = edict(config)
    inp = torch.randn(config.batch_size, config.input_channels,
                      config.patch_shape[0], config.patch_shape[1], config.patch_shape[2])
    print(inp.shape)
    netD = Discriminator(config)
    out = netD(inp)
    print(out)


if __name__ == '__main__':
    main()
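
For reference, the constructor reads `num_modalities` and `num_classes` from the config, while the test in `main()` additionally expects `batch_size`, `input_channels`, and `patch_shape` from `gan_exp_0.json`. Below is a minimal smoke-test sketch with a hand-built config; the field values are illustrative assumptions, not the contents of that JSON file, and it relies on the same behavior as `main()` above (i.e. `WN_Conv3d` preserving spatial size so the skip connections line up).

# Hypothetical stand-in config for a quick smoke test; values are assumptions.
import torch
from easydict import EasyDict as edict

config = edict({
    "num_modalities": 4,          # input channels read by Discriminator.__init__
    "num_classes": 3,             # output channels of final_conv
    "batch_size": 1,
    "input_channels": 4,          # should match num_modalities (used by main())
    "patch_shape": [32, 32, 32],  # divisible by 8 for the three pooling stages
})

netD = Discriminator(config)
inp = torch.randn(config.batch_size, config.input_channels, *config.patch_shape)

# get_feature=True additionally returns the bottleneck feature map (conv6),
# which downstream training code could use, e.g. for feature matching.
logits, probs, feat = netD(inp, get_feature=True, use_dropout=False)
print(logits.shape, probs.shape, feat.shape)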