Diff of /networks/discriminator.py [000000] .. [903821]

Switch to side-by-side view

--- a
+++ b/networks/discriminator.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2019/12/30 下午9:34
+# @Author  : chuyu zhang
+# @File    : discriminator.py
+# @Software: PyCharm
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+
+class FCDiscriminator(nn.Module):
+
+    def __init__(self, num_classes, ndf=64, n_channel=1):
+        super(FCDiscriminator, self).__init__()
+        self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
+        self.conv1 = nn.Conv2d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
+        self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
+        self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
+        self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
+        self.classifier = nn.Linear(ndf*8, 2)
+        self.avgpool = nn.AvgPool2d((7, 7))
+
+        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.dropout = nn.Dropout2d(0.5)
+        # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, map, feature):
+        map_feature = self.conv0(map)
+        image_feature = self.conv1(feature)
+        x = torch.add(map_feature, image_feature)
+
+        x = self.conv2(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv3(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv4(x)
+        x = self.leaky_relu(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        # x = self.up_sample(x)
+        # x = self.sigmoid(x)
+
+        return x
+
+
+class FC3DDiscriminator(nn.Module):
+
+    def __init__(self, num_classes, ndf=64, n_channel=1):
+        super(FC3DDiscriminator, self).__init__()
+        # downsample 16
+        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
+        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
+
+        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
+        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
+        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
+        self.avgpool = nn.AvgPool3d((7, 7, 5))
+        self.classifier = nn.Linear(ndf*8, 2)
+
+        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.dropout = nn.Dropout3d(0.5)
+        self.Softmax = nn.Softmax()
+
+    def forward(self, map, image):
+        batch_size = map.shape[0]
+        map_feature = self.conv0(map)
+        image_feature = self.conv1(image)
+        x = torch.add(map_feature, image_feature)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv2(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv3(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv4(x)
+        x = self.leaky_relu(x)
+
+        x = self.avgpool(x)
+
+        x = x.view(batch_size, -1)
+        x = self.classifier(x)
+        x = x.reshape((batch_size, 2))
+        # x = self.Softmax(x)
+
+        return x
+
+
+class FC3DDiscriminatorNIH(nn.Module):
+    def __init__(self, num_classes, ndf=64, n_channel=1):
+        super(FC3DDiscriminatorNIH, self).__init__()
+        # downsample 16
+        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
+        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
+
+        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
+        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
+        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
+        self.avgpool = nn.AvgPool3d((13, 10, 9))
+        self.classifier = nn.Linear(ndf*8, 2)
+
+        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.dropout = nn.Dropout3d(0.5)
+        self.Softmax = nn.Softmax()
+
+    def forward(self, map, image):
+        batch_size = map.shape[0]
+        map_feature = self.conv0(map)
+        image_feature = self.conv1(image)
+        x = torch.add(map_feature, image_feature)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv2(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv3(x)
+        x = self.leaky_relu(x)
+        x = self.dropout(x)
+
+        x = self.conv4(x)
+        x = self.leaky_relu(x)
+
+        x = self.avgpool(x)
+
+        x = x.view(batch_size, -1)
+        x = self.classifier(x)
+        x = x.reshape((batch_size, 2))
+        # x = self.Softmax(x)
+
+        return x
+
+
+class FCDiscriminatorDAP(nn.Module):
+    def __init__(self, num_classes, ndf = 64):
+        super(FCDiscriminatorDAP, self).__init__()
+
+        self.conv1 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
+        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
+        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
+        self.classifier = nn.Conv3d(ndf*4, 1, kernel_size=4, stride=2, padding=1)
+
+        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.up_sample = nn.Upsample(scale_factor=16, mode='trilinear', align_corners=True)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.leaky_relu(x)
+        x = self.conv2(x)
+        x = self.leaky_relu(x)
+        x = self.conv3(x)
+        x = self.leaky_relu(x)
+        x = self.classifier(x)
+        x = self.up_sample(x)
+        x = self.sigmoid(x)
+
+        return x
+
+if __name__ == '__main__':
+    # compute FLOPS & PARAMETERS
+    from thop import profile
+    from thop import clever_format
+    model = FC3DDiscriminator(num_classes=1)
+    input = torch.randn(4, 1, 112, 112, 80)
+    flops, params = profile(model, inputs=(input,input))
+    macs, params = clever_format([flops, params], "%.3f")
+    print(macs, params)
+
+    model = FCDiscriminatorDAP(num_classes=2)
+    input = torch.randn(4, 2, 112, 112, 80)
+    flops, params = profile(model, inputs=(input,))
+    macs, params = clever_format([flops, params], "%.3f")
+    print(macs, params)
+
+    import ipdb; ipdb.set_trace()
\ No newline at end of file