#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/12/30 9:34 PM
# @Author : chuyu zhang
# @File : discriminator.py
# @Software: PyCharm
import torch.nn as nn
import torch.nn.functional as F
import torch
class FCDiscriminator(nn.Module):
    """2D discriminator that scores a (map, image) pair with two logits.

    Both inputs are embedded by their own stride-2 convolution, fused by
    element-wise addition, passed through three more stride-2 convolutions,
    average-pooled and projected to two raw logits (no softmax is applied).

    Every convolution uses kernel 4 / stride 2 / padding 1, so each spatial
    dimension is halved four times (factor 16 overall).  ``pool_size`` must
    cover the resulting feature map so that pooling yields a 1x1 map;
    otherwise the flattened features will not match the classifier's
    ``ndf * 8`` input width.

    Args:
        num_classes: Number of channels of ``map`` (input of ``conv0``).
        ndf: Base number of feature channels.
        n_channel: Number of channels of ``feature`` (input of ``conv1``).
        pool_size: Kernel size of the final average pooling.  The default
            ``(7, 7)`` matches 112x112 inputs (112 / 16 = 7).
    """

    def __init__(self, num_classes, ndf=64, n_channel=1, pool_size=(7, 7)):
        super(FCDiscriminator, self).__init__()
        # Two parallel stems bring both inputs to the same channel width.
        self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
        # Shared trunk applied to the fused features.
        self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Linear(ndf*8, 2)
        self.avgpool = nn.AvgPool2d(pool_size)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout2d(0.5)

    def forward(self, map, feature):
        """Return a ``(N, 2)`` tensor of logits for the given pair.

        Args:
            map: Tensor of shape ``(N, num_classes, H, W)``.
            feature: Tensor of shape ``(N, n_channel, H, W)``.
        """
        # Fuse the two embeddings by element-wise sum.
        # NOTE(review): unlike the 3D discriminators in this file, no
        # activation/dropout is applied to the fused features before conv2;
        # kept as-is so existing trained weights behave identically.
        x = torch.add(self.conv0(map), self.conv1(feature))
        x = self.dropout(self.leaky_relu(self.conv2(x)))
        x = self.dropout(self.leaky_relu(self.conv3(x)))
        x = self.leaky_relu(self.conv4(x))
        x = self.avgpool(x)
        # Flatten to (N, ndf * 8) for the linear classifier.
        x = x.view(x.size(0), -1)
        return self.classifier(x)
class FC3DDiscriminator(nn.Module):
    """3D discriminator that scores a (map, image) volume pair with two logits.

    Both inputs are embedded by their own stride-2 3D convolution, fused by
    element-wise addition, passed through three more stride-2 convolutions,
    average-pooled and projected to two raw logits (no softmax is applied).

    Every convolution uses kernel 4 / stride 2 / padding 1, so each spatial
    dimension is halved four times (factor 16 overall).  ``pool_size`` must
    cover the resulting feature volume so that pooling yields a 1x1x1
    volume; otherwise the flattened features will not match the
    classifier's ``ndf * 8`` input width.

    Args:
        num_classes: Number of channels of ``map`` (input of ``conv0``).
        ndf: Base number of feature channels.
        n_channel: Number of channels of ``image`` (input of ``conv1``).
        pool_size: Kernel size of the final average pooling.  The default
            ``(7, 7, 5)`` matches 112x112x80 inputs.
    """

    def __init__(self, num_classes, ndf=64, n_channel=1, pool_size=(7, 7, 5)):
        super(FC3DDiscriminator, self).__init__()
        # Two parallel stems bring both inputs to the same channel width.
        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
        # Shared trunk; together with the stems this downsamples by 16.
        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.avgpool = nn.AvgPool3d(pool_size)
        self.classifier = nn.Linear(ndf*8, 2)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout3d(0.5)
        # Not used by forward(); kept as an attribute for callers that want
        # probabilities.  dim=1 is explicit because the implicit-dim form is
        # deprecated and warns when called.
        self.Softmax = nn.Softmax(dim=1)

    def forward(self, map, image):
        """Return a ``(N, 2)`` tensor of logits for the given pair.

        Args:
            map: Tensor of shape ``(N, num_classes, D, H, W)``.
            image: Tensor of shape ``(N, n_channel, D, H, W)``.
        """
        # Fuse the two embeddings by element-wise sum.
        x = torch.add(self.conv0(map), self.conv1(image))
        x = self.dropout(self.leaky_relu(x))
        x = self.dropout(self.leaky_relu(self.conv2(x)))
        x = self.dropout(self.leaky_relu(self.conv3(x)))
        x = self.leaky_relu(self.conv4(x))
        x = self.avgpool(x)
        # Flatten to (N, ndf * 8); the classifier output is already (N, 2),
        # so no further reshape is needed.
        x = x.view(map.shape[0], -1)
        return self.classifier(x)
class FC3DDiscriminatorNIH(nn.Module):
    """3D (map, image) discriminator with a larger default pooling window.

    Same architecture as a two-stem 3D discriminator: each input is embedded
    by its own stride-2 convolution, the embeddings are summed, three more
    stride-2 convolutions follow, and the average-pooled features are
    projected to two raw logits (no softmax is applied).

    Every convolution uses kernel 4 / stride 2 / padding 1, so each spatial
    dimension is halved four times (factor 16 overall).  ``pool_size`` must
    cover the resulting feature volume so that pooling yields a 1x1x1
    volume; otherwise the flattened features will not match the
    classifier's ``ndf * 8`` input width.

    Args:
        num_classes: Number of channels of ``map`` (input of ``conv0``).
        ndf: Base number of feature channels.
        n_channel: Number of channels of ``image`` (input of ``conv1``).
        pool_size: Kernel size of the final average pooling.  The default
            ``(13, 10, 9)`` matches e.g. 208x160x144 inputs
            (presumably the NIH dataset crop size -- TODO confirm).
    """

    def __init__(self, num_classes, ndf=64, n_channel=1, pool_size=(13, 10, 9)):
        super(FC3DDiscriminatorNIH, self).__init__()
        # Two parallel stems bring both inputs to the same channel width.
        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
        # Shared trunk; together with the stems this downsamples by 16.
        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.avgpool = nn.AvgPool3d(pool_size)
        self.classifier = nn.Linear(ndf*8, 2)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout3d(0.5)
        # Not used by forward(); kept as an attribute for callers that want
        # probabilities.  dim=1 is explicit because the implicit-dim form is
        # deprecated and warns when called.
        self.Softmax = nn.Softmax(dim=1)

    def forward(self, map, image):
        """Return a ``(N, 2)`` tensor of logits for the given pair.

        Args:
            map: Tensor of shape ``(N, num_classes, D, H, W)``.
            image: Tensor of shape ``(N, n_channel, D, H, W)``.
        """
        # Fuse the two embeddings by element-wise sum.
        x = torch.add(self.conv0(map), self.conv1(image))
        x = self.dropout(self.leaky_relu(x))
        x = self.dropout(self.leaky_relu(self.conv2(x)))
        x = self.dropout(self.leaky_relu(self.conv3(x)))
        x = self.leaky_relu(self.conv4(x))
        x = self.avgpool(x)
        # Flatten to (N, ndf * 8); the classifier output is already (N, 2),
        # so no further reshape is needed.
        x = x.view(map.shape[0], -1)
        return self.classifier(x)
class FCDiscriminatorDAP(nn.Module):
    """Fully convolutional 3D discriminator producing a per-voxel score map.

    Four stride-2 convolutions (kernel 4, padding 1) reduce every spatial
    dimension by a factor of 16 and end in a single-channel map, which is
    trilinearly upsampled back to the input resolution and squashed to
    ``[0, 1]`` with a sigmoid.

    Args:
        num_classes: Number of channels of the input volume.
        ndf: Base number of feature channels.
    """

    def __init__(self, num_classes, ndf=64):
        super(FCDiscriminatorDAP, self).__init__()
        self.conv1 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Conv3d(ndf*4, 1, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(N, 1, D, H, W)`` map of scores in ``[0, 1]``.

        Args:
            x: Tensor of shape ``(N, num_classes, D, H, W)``.
        """
        # Remember the input resolution: a fixed x16 upsampling only restores
        # it when every spatial dimension is a multiple of 16.
        spatial_size = tuple(x.shape[2:])
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        x = self.classifier(x)
        # With align_corners=True the interpolation depends only on the input
        # and output sizes, so this matches the former scale_factor=16
        # upsampling whenever the dimensions are multiples of 16.
        x = F.interpolate(x, size=spatial_size, mode='trilinear', align_corners=True)
        return self.sigmoid(x)
if __name__ == '__main__':
    # Report the operation count and parameter count of two discriminators
    # using thop (imported here so the module itself does not depend on it).
    from thop import profile
    from thop import clever_format

    # Two-input 3D discriminator: the same dummy volume is fed as both the
    # map and the image.
    model = FC3DDiscriminator(num_classes=1)
    dummy = torch.randn(4, 1, 112, 112, 80)
    flops, params = profile(model, inputs=(dummy, dummy))
    macs, params = clever_format([flops, params], "%.3f")
    print(macs, params)

    # Single-input fully convolutional discriminator.
    model = FCDiscriminatorDAP(num_classes=2)
    dummy = torch.randn(4, 2, 112, 112, 80)
    flops, params = profile(model, inputs=(dummy,))
    macs, params = clever_format([flops, params], "%.3f")
    print(macs, params)