--- a +++ b/networks/discriminator.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2019/12/30 下午9:34 +# @Author : chuyu zhang +# @File : discriminator.py +# @Software: PyCharm + +import torch.nn as nn +import torch.nn.functional as F +import torch + + +class FCDiscriminator(nn.Module): + + def __init__(self, num_classes, ndf=64, n_channel=1): + super(FCDiscriminator, self).__init__() + self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) + self.conv1 = nn.Conv2d(n_channel, ndf, kernel_size=4, stride=2, padding=1) + self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) + self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) + self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) + self.classifier = nn.Linear(ndf*8, 2) + self.avgpool = nn.AvgPool2d((7, 7)) + + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.dropout = nn.Dropout2d(0.5) + # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') + # self.sigmoid = nn.Sigmoid() + + def forward(self, map, feature): + map_feature = self.conv0(map) + image_feature = self.conv1(feature) + x = torch.add(map_feature, image_feature) + + x = self.conv2(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv3(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv4(x) + x = self.leaky_relu(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + # x = self.up_sample(x) + # x = self.sigmoid(x) + + return x + + +class FC3DDiscriminator(nn.Module): + + def __init__(self, num_classes, ndf=64, n_channel=1): + super(FC3DDiscriminator, self).__init__() + # downsample 16 + self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) + self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1) + + self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) + self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) + self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) + self.avgpool = nn.AvgPool3d((7, 7, 5)) + self.classifier = nn.Linear(ndf*8, 2) + + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.dropout = nn.Dropout3d(0.5) + self.Softmax = nn.Softmax() + + def forward(self, map, image): + batch_size = map.shape[0] + map_feature = self.conv0(map) + image_feature = self.conv1(image) + x = torch.add(map_feature, image_feature) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv2(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv3(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv4(x) + x = self.leaky_relu(x) + + x = self.avgpool(x) + + x = x.view(batch_size, -1) + x = self.classifier(x) + x = x.reshape((batch_size, 2)) + # x = self.Softmax(x) + + return x + + +class FC3DDiscriminatorNIH(nn.Module): + def __init__(self, num_classes, ndf=64, n_channel=1): + super(FC3DDiscriminatorNIH, self).__init__() + # downsample 16 + self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) + self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1) + + self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) + self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) + self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) + self.avgpool = nn.AvgPool3d((13, 10, 9)) + self.classifier = nn.Linear(ndf*8, 2) + + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.dropout = nn.Dropout3d(0.5) + self.Softmax = nn.Softmax() + + def forward(self, map, image): + batch_size = map.shape[0] + map_feature = self.conv0(map) + image_feature = self.conv1(image) + x = torch.add(map_feature, image_feature) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv2(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv3(x) + x = self.leaky_relu(x) + x = self.dropout(x) + + x = self.conv4(x) + x = self.leaky_relu(x) + + x = self.avgpool(x) + + x = x.view(batch_size, -1) + x = self.classifier(x) + x = x.reshape((batch_size, 2)) + # x = self.Softmax(x) + + return x + + +class FCDiscriminatorDAP(nn.Module): + def __init__(self, num_classes, ndf = 64): + super(FCDiscriminatorDAP, self).__init__() + + self.conv1 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) + self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) + self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) + self.classifier = nn.Conv3d(ndf*4, 1, kernel_size=4, stride=2, padding=1) + + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.up_sample = nn.Upsample(scale_factor=16, mode='trilinear', align_corners=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = self.conv1(x) + x = self.leaky_relu(x) + x = self.conv2(x) + x = self.leaky_relu(x) + x = self.conv3(x) + x = self.leaky_relu(x) + x = self.classifier(x) + x = self.up_sample(x) + x = self.sigmoid(x) + + return x + +if __name__ == '__main__': + # compute FLOPS & PARAMETERS + from thop import profile + from thop import clever_format + model = FC3DDiscriminator(num_classes=1) + input = torch.randn(4, 1, 112, 112, 80) + flops, params = profile(model, inputs=(input,input)) + macs, params = clever_format([flops, params], "%.3f") + print(macs, params) + + model = FCDiscriminatorDAP(num_classes=2) + input = torch.randn(4, 2, 112, 112, 80) + flops, params = profile(model, inputs=(input,)) + macs, params = clever_format([flops, params], "%.3f") + print(macs, params) + + import ipdb; ipdb.set_trace() \ No newline at end of file