Diff of /inpainting/model/net.py [000000] .. [92cc18]

--- a
+++ b/inpainting/model/net.py
@@ -0,0 +1,356 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from inpainting.model.basemodel import BaseModel
+from inpainting.model.basenet import BaseNet
+from inpainting.model.loss import WGANLoss, IDMRFLoss
+from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
+import numpy as np
+
+
+# generative multi-column convolutional neural net
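+# Three parallel encoder branches (EB1/EB2/EB3) with 7x7, 5x5, and 3x3 kernels
+# capture structure at different receptive fields; their features are upsampled
+# back to input resolution, concatenated, and fused by a small shared decoder.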
+class GMCNN(BaseNet):
+    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
+        super(GMCNN, self).__init__()
+        self.act = act
+        self.using_norm = using_norm
+        self.norm = norm if using_norm else None
+        ch = cnum
+
+        # network structure
+        self.EB1 = []
+        self.EB2 = []
+        self.EB3 = []
+        self.decoding_layers = []
+
+        self.EB1_pad_rec = []
+        self.EB2_pad_rec = []
+        self.EB3_pad_rec = []
+
+        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))
+
+        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
+        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))
+
+        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
+
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))
+
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
+        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
+
+        self.EB1.append(PureUpsampling(scale=4))
+
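+        # reflection-pad size per layer: dilation * (kernel_size - 1) / 2, i.e. 3
+        # for plain 7x7 convs and 6/12/24/48 at dilations 2/4/8/16; a 0 marks the
+        # parameter-free upsampling step, which gets no padding (and no activation)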
+        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]
+
+        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))
+
+        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
+        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
+
+        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
+
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))
+
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
+
+        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
+        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
+        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
+        self.EB2.append(PureUpsampling(scale=2))
+        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]
+
+        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))
+
+        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
+        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
+
+        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
+
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))
+
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
+
+        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
+        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
+        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
+        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
+        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
+        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))
+
+        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]
+
+        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
+        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))
+
+        self.decoding_pad_rec = [1, 1]
+
+        self.EB1 = nn.ModuleList(self.EB1)
+        self.EB2 = nn.ModuleList(self.EB2)
+        self.EB3 = nn.ModuleList(self.EB3)
+        self.decoding_layers = nn.ModuleList(self.decoding_layers)
+
+        # pre-build one ReflectionPad2d per pad size: the largest pad needed is
+        # 48 (a 7x7 kernel at dilation 16 in EB1), hence 49 entries (0..48)
+        padlen = 49
+        self.pads = [0] * padlen
+        for i in range(padlen):
+            self.pads[i] = nn.ReflectionPad2d(i)
+        self.pads = nn.ModuleList(self.pads)
+
+    def forward(self, x):
+        x1, x2, x3 = x, x, x
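+        # run the three columns in parallel on the same input; pad_idx == 0
+        # identifies PureUpsampling layers, which get neither padding nor activation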
+        for i, layer in enumerate(self.EB1):
+            pad_idx = self.EB1_pad_rec[i]
+            x1 = layer(self.pads[pad_idx](x1))
+            if self.using_norm:
+                x1 = self.norm(x1)
+            if pad_idx != 0:
+                x1 = self.act(x1)
+
+        for i, layer in enumerate(self.EB2):
+            pad_idx = self.EB2_pad_rec[i]
+            x2 = layer(self.pads[pad_idx](x2))
+            if self.using_norm:
+                x2 = self.norm(x2)
+            if pad_idx != 0:
+                x2 = self.act(x2)
+
+        for i, layer in enumerate(self.EB3):
+            pad_idx = self.EB3_pad_rec[i]
+            x3 = layer(self.pads[pad_idx](x3))
+            if self.using_norm:
+                x3 = self.norm(x3)
+            if pad_idx != 0:
+                x3 = self.act(x3)
+
+        x_d = torch.cat((x1, x2, x3), 1)
+        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
+        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
+        x_out = torch.clamp(x_d, -1, 1)
+        return x_out
+
+
+# returns a one-dimensional critic score of realness (an unbounded logit, not a probability)
+class Discriminator(BaseNet):
+    def __init__(self, in_channels, cnum=32, fc_channels=8 * 8 * 32 * 4, act=F.elu, norm=None, spectral_norm=True):
+        super(Discriminator, self).__init__()
+        self.act = act
+        self.norm = norm
+        self.embedding = None
+        self.logit = None
+
+        ch = cnum
+        self.layers = []
+        if spectral_norm:
+            self.layers.append(SpectralNorm(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2)))
+            self.layers.append(SpectralNorm(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2)))
+            self.layers.append(SpectralNorm(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2)))
+            self.layers.append(SpectralNorm(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2)))
+            self.layers.append(SpectralNorm(nn.Linear(fc_channels, 1)))
+        else:
+            self.layers.append(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2))
+            self.layers.append(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2))
+            self.layers.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2))
+            self.layers.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2))
+            self.layers.append(nn.Linear(fc_channels, 1))
+        self.layers = nn.ModuleList(self.layers)
+
+    def forward(self, x):
+        for layer in self.layers[:-1]:
+            x = layer(x)
+            if self.norm is not None:
+                x = self.norm(x)
+            x = self.act(x)
+        self.embedding = x.view(x.size(0), -1)
+        self.logit = self.layers[-1](self.embedding)
+        return self.logit
+
+
+class GlobalLocalDiscriminator(BaseNet):
+    def __init__(self, in_channels, cnum=32, g_fc_channels=16 * 16 * 32 * 4, l_fc_channels=8 * 8 * 32 * 4, act=F.elu,
+                 norm=None,
+                 spectral_norm=True):
+        super(GlobalLocalDiscriminator, self).__init__()
+        self.act = act
+        self.norm = norm
+
+        self.global_discriminator = Discriminator(in_channels=in_channels, fc_channels=g_fc_channels, cnum=cnum,
+                                                  act=act, norm=norm, spectral_norm=spectral_norm)
+        self.local_discriminator = Discriminator(in_channels=in_channels, fc_channels=l_fc_channels, cnum=cnum,
+                                                 act=act, norm=norm, spectral_norm=spectral_norm)
+
+    def forward(self, x_g, x_l):
+        x_global = self.global_discriminator(x_g)
+        x_local = self.local_discriminator(x_l)
+        return x_global, x_local
+
+
+from inpainting.util.utils import generate_mask
+
+
+class InpaintingModel_GMCNN(BaseModel):
+    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
+        super(InpaintingModel_GMCNN, self).__init__()
+        self.opt = opt
+        self.init(opt)
+
+        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
+
+        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda()
+        init_weights(self.netGM)
+        self.model_names = ['GM']
+        if self.opt.phase == 'test':
+            return
+
+        self.netD = None
+
+        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
+        self.optimizer_D = None
+
+        self.wganloss = None
+        self.recloss = nn.L1Loss()
+        self.aeloss = nn.L1Loss()
+        self.mrfloss = None
+        self.lambda_adv = opt.lambda_adv
+        self.lambda_rec = opt.lambda_rec
+        self.lambda_ae = opt.lambda_ae
+        self.lambda_gp = opt.lambda_gp
+        self.lambda_mrf = opt.lambda_mrf
+        self.G_loss = None
+        self.G_loss_reconstruction = None
+        self.G_loss_mrf = None
+        self.G_loss_adv, self.G_loss_adv_local = None, None
+        self.G_loss_ae = None
+        self.D_loss, self.D_loss_local = None, None
+        self.GAN_loss = None
+
+        self.gt, self.gt_local = None, None
+        self.mask, self.mask_01 = None, None
+        self.rect = None
+        self.im_in, self.gin = None, None
+
+        self.completed, self.completed_local = None, None
+        self.completed_logit, self.completed_local_logit = None, None
+        self.gt_logit, self.gt_local_logit = None, None
+
+        self.pred = None
+
+        if self.opt.pretrain_network is False:
+            self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
+                                                 spectral_norm=self.opt.spectral_norm,
+                                                 g_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[
+                                                     1] // 16 * opt.d_cnum * 4,
+                                                 l_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[
+                                                     1] // 16 * opt.d_cnum * 4).cuda()
+            init_weights(self.netD)
+            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr,
+                                                betas=(0.5, 0.9))
+            self.wganloss = WGANLoss()
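+            # NOTE: forward_D below uses the hinge objective; this WGANLoss
+            # instance is not referenced elsewhere in this file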
+            self.mrfloss = IDMRFLoss()
+
+    def initVariables(self):
+        self.gt = self.img
+        # self.mask is expected to hold the binary {0,1} hole mask (set externally);
+        # keep it as mask_01 and derive the smooth confidence-driven mask from it.
+        # Random masks could instead be generated here:
+        # mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)
+        # self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])
+        self.mask_01 = self.mask
+        self.mask = self.confidence_mask_layer(self.mask_01)
+        self.gt_local = self.gt
+        self.im_in = self.gt * (1 - self.mask_01)
+        self.gin = torch.cat((self.im_in, self.mask_01), 1)
+
+    def forward_G(self):
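+        # confidence-weighted L1 inside the hole, normalized by the mean hole area,
+        # plus plain L1 outside the hole (autoencoder term), normalized likewise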
+        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
+        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
+        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
+        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
+        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
+
+        # adversarial and MRF terms apply only when a discriminator exists,
+        # i.e. outside the reconstruction-only pretraining phase
+        if self.opt.pretrain_network is False:
+            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
+            self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
+            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf
+
+            self.G_loss_adv = -self.completed_logit.mean()
+            self.G_loss_adv_local = -self.completed_local_logit.mean()
+            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)
+
+    def forward_D(self):
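+        # fake samples are detached so the D update does not backprop into G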
+        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
+                                                                     self.completed_local.detach())
+        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
+        # hinge loss
+        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(
+            1.0 + self.completed_local_logit).mean()
+        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
+        self.D_loss = self.D_loss + self.D_loss_local
+
+    def backward_G(self):
+        self.G_loss.backward()
+
+    def backward_D(self):
+        self.D_loss.backward(retain_graph=True)
+
+    def optimize_parameters(self):
+        self.initVariables()
+
+        self.pred = self.netGM(self.gin)
+        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
+        self.completed_local = self.completed
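+        # NOTE: no local crop here; the local discriminator sees the full image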
+
+        if self.opt.pretrain_network is False:
+            for i in range(self.opt.D_max_iters):
+                self.optimizer_D.zero_grad()
+                self.optimizer_G.zero_grad()
+                self.forward_D()
+                self.backward_D()
+                self.optimizer_D.step()
+
+        self.optimizer_G.zero_grad()
+        self.forward_G()
+        self.backward_G()
+        self.optimizer_G.step()
+
+    def get_current_losses(self):
+        l = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
+             'G_loss_ae': self.G_loss_ae.item()}
+        if self.opt.pretrain_network is False:
+            l.update({'G_loss_adv': self.G_loss_adv.item(),
+                      'G_loss_adv_local': self.G_loss_adv_local.item(),
+                      'D_loss': self.D_loss.item(),
+                      'G_loss_mrf': self.G_loss_mrf.item()})
+        return l
+
+    def get_current_visuals(self):
+        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
+                'completed': self.completed.cpu().detach().numpy()}
+
+    def get_current_visuals_tensor(self):
+        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
+                'completed': self.completed.cpu().detach()}
+
+    def evaluate(self, im_in, mask):
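+        # im_in: NumPy image batch, NCHW layout, pixel values in [0, 255];
+        # mask: NumPy {0,1} hole mask of matching spatial size. The image is
+        # mapped to [-1, 1], holes are filled by the generator, and the
+        # composite is returned as uint8 in [0, 255].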
+        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
+        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
+        im_in = im_in * (1 - mask)
+        xin = torch.cat((im_in, mask), 1)
+        ret = self.netGM(xin) * mask + im_in * (1 - mask)
+        ret = (ret.cpu().detach().numpy() + 1) * 127.5
+        return ret.astype(np.uint8)
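
Usage note: a minimal CPU smoke test of the generator (a sketch, assuming the
inpainting package is importable; the 256x256 size and the 4 input channels,
i.e. a masked RGB image plus the binary mask as in gin above, are illustrative):

    import torch
    from inpainting.model.net import GMCNN

    net = GMCNN(in_channels=4, out_channels=3, cnum=32)
    x = torch.randn(1, 4, 256, 256)  # masked RGB image + mask channel
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # torch.Size([1, 3, 256, 256]), values clamped to [-1, 1]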