--- /dev/null
+++ b/inpainting/model/layer.py
@@ -0,0 +1,335 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from util.utils import gauss_kernel
+import torchvision.models as models
+import numpy as np
+
+
+class Conv2d_BN(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2d_BN, self).__init__()
+        # nn.Sequential takes modules as positional args (or an OrderedDict),
+        # not a plain list
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias),
+            nn.BatchNorm2d(out_channels)
+        )
+
+    def forward(self, *input):
+        return self.model(*input)
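+
+
+# Usage sketch (illustrative shapes, not from the training code): a 3x3 'same'
+# convolution followed by batch norm keeps the spatial size and changes only
+# the channel count.
+def _demo_conv2d_bn():
+    m = Conv2d_BN(3, 16, kernel_size=3, padding=1)
+    y = m(torch.randn(4, 3, 32, 32))
+    assert y.shape == (4, 16, 32, 32)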
+
+
+class upsampling(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, scale=2):
+        super(upsampling, self).__init__()
+        assert isinstance(scale, int)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
+                     dilation=dilation, groups=groups, bias=bias)
+        self.scale = scale
+
+    def forward(self, x):
+        h, w = x.size(2) * self.scale, x.size(3) * self.scale
+        # align_corners is not accepted for mode='nearest'
+        xout = self.conv(F.interpolate(input=x, size=(h, w), mode='nearest'))
+        return xout
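+
+
+# Usage sketch (illustrative shapes): nearest-neighbor resize by `scale`, then
+# a conv to refine; padding=1 keeps the 3x3 conv output at the upsampled size.
+def _demo_upsampling():
+    up = upsampling(64, 32, kernel_size=3, padding=1, scale=2)
+    y = up(torch.randn(1, 64, 16, 16))
+    assert y.shape == (1, 32, 32, 32)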
+
+
+class PureUpsampling(nn.Module):
+    def __init__(self, scale=2, mode='bilinear'):
+        super(PureUpsampling, self).__init__()
+        assert isinstance(scale, int)
+        self.scale = scale
+        self.mode = mode
+
+    def forward(self, x):
+        h, w = x.size(2) * self.scale, x.size(3) * self.scale
+        if self.mode == 'nearest':
+            xout = F.interpolate(input=x, size=(h, w), mode=self.mode)
+        else:
+            xout = F.interpolate(input=x, size=(h, w), mode=self.mode, align_corners=True)
+        return xout
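+
+
+# Usage sketch: parameter-free resize; bilinear by default, nearest optional.
+def _demo_pure_upsampling():
+    up = PureUpsampling(scale=4, mode='nearest')
+    y = up(torch.randn(1, 3, 8, 8))
+    assert y.shape == (1, 3, 32, 32)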
+
+
+class GaussianBlurLayer(nn.Module):
+    def __init__(self, size, sigma, in_channels=1, stride=1, pad=1):
+        super(GaussianBlurLayer, self).__init__()
+        self.size = size
+        self.sigma = sigma
+        self.ch = in_channels
+        self.stride = stride
+        self.pad = nn.ReflectionPad2d(pad)
+        # build the Gaussian kernel once; a registered buffer moves with the
+        # module on .to()/.cuda() instead of hard-coding a device per forward
+        kernel = gauss_kernel(self.size, self.sigma, self.ch, self.ch)
+        self.register_buffer('kernel', torch.from_numpy(kernel))
+
+    def forward(self, x):
+        x = self.pad(x)
+        blurred = F.conv2d(x, self.kernel, stride=self.stride)
+        return blurred
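+
+
+# Usage sketch: with pad = size // 2 the blur preserves spatial size. This
+# assumes gauss_kernel returns an array shaped (ch, ch, size, size), which is
+# the weight layout F.conv2d expects.
+def _demo_gaussian_blur():
+    blur = GaussianBlurLayer(size=5, sigma=1.0, in_channels=1, pad=2)
+    y = blur(torch.rand(1, 1, 16, 16))
+    assert y.shape == (1, 1, 16, 16)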
+
+
+class ConfidenceDrivenMaskLayer(nn.Module):
+    def __init__(self, size=65, sigma=1.0/40, iters=7):
+        super(ConfidenceDrivenMaskLayer, self).__init__()
+        self.size = size
+        self.sigma = sigma
+        self.iters = iters
+        self.propagationLayer = GaussianBlurLayer(size, sigma, pad=32)
+
+    def forward(self, mask):
+        # in `mask`, 1 marks missing pixels and 0 marks valid pixels
+        init = 1 - mask
+        mask_confidence = None
+        for i in range(self.iters):
+            mask_confidence = self.propagationLayer(init)
+            mask_confidence = mask_confidence * mask
+            init = mask_confidence + (1 - mask)
+        return mask_confidence
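+
+
+# Usage sketch: confidence spreads from the known region into the hole over
+# `iters` blur-and-mask steps, so hole pixels get values that decay with
+# distance from the boundary (assuming gauss_kernel is normalized).
+def _demo_confidence_mask():
+    mask = torch.zeros(1, 1, 64, 64)
+    mask[:, :, 16:48, 16:48] = 1  # 1 = missing
+    conf = ConfidenceDrivenMaskLayer()(mask)
+    assert conf.shape == mask.shape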
+
+
+class VGG19(nn.Module):
+    def __init__(self, pool='max'):
+        super(VGG19, self).__init__()
+        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
+        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
+        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        if pool == 'max':
+            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif pool == 'avg':
+            self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
+            self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
+            self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
+            self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
+            self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
+        else:
+            # fail here rather than with an AttributeError in forward()
+            raise ValueError("pool must be 'max' or 'avg', got '{}'".format(pool))
+
+    def forward(self, x):
+        out = {}
+        out['r11'] = F.relu(self.conv1_1(x))
+        out['r12'] = F.relu(self.conv1_2(out['r11']))
+        out['p1'] = self.pool1(out['r12'])
+        out['r21'] = F.relu(self.conv2_1(out['p1']))
+        out['r22'] = F.relu(self.conv2_2(out['r21']))
+        out['p2'] = self.pool2(out['r22'])
+        out['r31'] = F.relu(self.conv3_1(out['p2']))
+        out['r32'] = F.relu(self.conv3_2(out['r31']))
+        out['r33'] = F.relu(self.conv3_3(out['r32']))
+        out['r34'] = F.relu(self.conv3_4(out['r33']))
+        out['p3'] = self.pool3(out['r34'])
+        out['r41'] = F.relu(self.conv4_1(out['p3']))
+        out['r42'] = F.relu(self.conv4_2(out['r41']))
+        out['r43'] = F.relu(self.conv4_3(out['r42']))
+        out['r44'] = F.relu(self.conv4_4(out['r43']))
+        out['p4'] = self.pool4(out['r44'])
+        out['r51'] = F.relu(self.conv5_1(out['p4']))
+        out['r52'] = F.relu(self.conv5_2(out['r51']))
+        out['r53'] = F.relu(self.conv5_3(out['r52']))
+        out['r54'] = F.relu(self.conv5_4(out['r53']))
+        out['p5'] = self.pool5(out['r54'])
+        return out
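+
+
+# Usage sketch: this constructor leaves the weights random; callers typically
+# load a pretrained state_dict before using it as a feature extractor. Keys
+# follow the r{block}{conv} / p{block} scheme of forward().
+def _demo_vgg19():
+    feats = VGG19()(torch.randn(1, 3, 64, 64))
+    assert feats['r54'].shape == (1, 512, 4, 4)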
+
+
+class VGG19FeatLayer(nn.Module):
+    def __init__(self):
+        super(VGG19FeatLayer, self).__init__()
+        self.vgg19 = models.vgg19(pretrained=True).features.eval().cuda()
+        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
+
+    def forward(self, x):
+        out = {}
+        x = x - self.mean
+        ci = 1
+        ri = 0
+        for layer in self.vgg19.children():
+            if isinstance(layer, nn.Conv2d):
+                ri += 1
+                name = 'conv{}_{}'.format(ci, ri)
+            elif isinstance(layer, nn.ReLU):
+                # reuse the preceding conv's index so the names pair up as
+                # conv{i}_{j} / relu{i}_{j}; incrementing here would skew both
+                name = 'relu{}_{}'.format(ci, ri)
+                layer = nn.ReLU(inplace=False)
+            elif isinstance(layer, nn.MaxPool2d):
+                ri = 0
+                name = 'pool_{}'.format(ci)
+                ci += 1
+            elif isinstance(layer, nn.BatchNorm2d):
+                name = 'bn_{}'.format(ci)
+            else:
+                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
+            x = layer(x)
+            out[name] = x
+        return out
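+
+
+# Usage sketch: needs a CUDA device and the torchvision VGG19 checkpoint;
+# inputs are assumed to lie roughly in [0, 1], since only the ImageNet mean is
+# subtracted.
+def _demo_vgg19_feat_layer():
+    feat = VGG19FeatLayer()
+    out = feat(torch.rand(1, 3, 64, 64).cuda())
+    assert out['relu1_1'].shape == (1, 64, 64, 64)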
+
+
+def init_weights(net, init_type='normal', gain=0.02):
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                nn.init.normal_(m.weight.data, 0.0, gain)
+            elif init_type == 'xavier':
+                nn.init.xavier_normal_(m.weight.data, gain=gain)
+            elif init_type == 'kaiming':
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                nn.init.orthogonal_(m.weight.data, gain=gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                nn.init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:
+            nn.init.normal_(m.weight.data, 1.0, gain)
+            nn.init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)
+
+
+def init_net(net, init_type='normal', gpu_ids=[]):
+    if len(gpu_ids) > 0:
+        assert(torch.cuda.is_available())
+        net.to(gpu_ids[0])
+        net = torch.nn.DataParallel(net, gpu_ids)
+    init_weights(net, init_type)
+    return net
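+
+
+# Usage sketch: init_weights is normally reached through init_net; with an
+# empty gpu_ids list the network stays on CPU and is just re-initialized.
+def _demo_init_net():
+    net = init_net(nn.Conv2d(3, 8, kernel_size=3), init_type='xavier', gpu_ids=[])
+    assert isinstance(net, nn.Conv2d)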
+
+
+def l2normalize(v, eps=1e-12):
+    return v / (v.norm()+eps)
+
+
+class SpectralNorm(nn.Module):
+    def __init__(self, module, name='weight', power_iteration=1):
+        super(SpectralNorm, self).__init__()
+        self.module = module
+        self.name = name
+        self.power_iteration = power_iteration
+        if not self._made_params():
+            self._make_params()
+
+    def _update_u_v(self):
+        u = getattr(self.module, self.name + '_u')
+        v = getattr(self.module, self.name + '_v')
+        w = getattr(self.module, self.name + '_bar')
+
+        height = w.data.shape[0]
+        for _ in range(self.power_iteration):
+            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+        sigma = u.dot(w.view(height, -1).mv(v))
+        setattr(self.module, self.name, w / sigma.expand_as(w))
+
+    def _made_params(self):
+        try:
+            u = getattr(self.module, self.name + '_u')
+            v = getattr(self.module, self.name + '_v')
+            w = getattr(self.module, self.name + '_bar')
+            return True
+        except AttributeError:
+            return False
+
+    def _make_params(self):
+        w = getattr(self.module, self.name)
+
+        height = w.data.shape[0]
+        width = w.view(height, -1).data.shape[1]
+
+        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+        u.data = l2normalize(u.data)
+        v.data = l2normalize(v.data)
+        w_bar = nn.Parameter(w.data)
+
+        del self.module._parameters[self.name]
+
+        self.module.register_parameter(self.name+'_u', u)
+        self.module.register_parameter(self.name+'_v', v)
+        self.module.register_parameter(self.name+'_bar', w_bar)
+
+    def forward(self, *input):
+        self._update_u_v()
+        return self.module.forward(*input)
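+
+
+# Usage sketch: wrapping a module rescales its weight by an estimate of its
+# largest singular value, refreshed with one power iteration per forward pass.
+def _demo_spectral_norm():
+    sn_conv = SpectralNorm(nn.Conv2d(3, 8, kernel_size=3, padding=1))
+    y = sn_conv(torch.randn(1, 3, 16, 16))
+    assert y.shape == (1, 8, 16, 16)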
+
+
+class PartialConv(nn.Module):
+    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1):
+        super(PartialConv, self).__init__()
+        self.ksize = ksize
+        self.stride = stride
+        self.padSize = self.ksize // 2
+        self.pad = nn.ReflectionPad2d(self.padSize)
+        self.epsilon = 1e-5
+        self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)
+
+    def forward(self, x, mask):
+        # all-ones kernel that counts the valid pixels in each sliding window
+        mask_ch = mask.size(1)
+        sum_kernel_np = np.ones((mask_ch, mask_ch, self.ksize, self.ksize), dtype=np.float32)
+        sum_kernel = torch.from_numpy(sum_kernel_np).to(mask.device)
+
+        # renormalize the masked input by the local count of valid pixels
+        x = x * mask / (F.conv2d(mask, sum_kernel, stride=1, padding=self.padSize) + self.epsilon)
+        x = self.pad(x)
+        x = self.conv(x)
+        # a window stays valid if it contains at least one valid pixel
+        mask = F.max_pool2d(mask, self.ksize, stride=self.stride, padding=self.padSize)
+        return x, mask
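+
+
+# Usage sketch: here mask 1 marks valid pixels (it multiplies the input), the
+# reverse of ConfidenceDrivenMaskLayer; a single-channel mask broadcasts over
+# the image channels.
+def _demo_partial_conv():
+    pconv = PartialConv(3, 8, ksize=3)
+    mask = torch.ones(1, 1, 16, 16)
+    mask[:, :, 4:12, 4:12] = 0  # a hole
+    y, new_mask = pconv(torch.randn(1, 3, 16, 16), mask)
+    assert y.shape == (1, 8, 16, 16) and new_mask.shape == mask.shape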
+
+
+class GatedConv(nn.Module):
+    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1, act=F.elu):
+        super(GatedConv, self).__init__()
+        self.ksize = ksize
+        self.stride = stride
+        self.act = act
+        self.padSize = self.ksize // 2
+        self.pad = nn.ReflectionPad2d(self.padSize)
+        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)
+        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize,
+                               padding=self.padSize)
+
+    def forward(self, x):
+        # feature branch: reflection-padded conv plus activation
+        f = self.act(self.convf(self.pad(x)))
+        # gate branch: fed the original input (convm zero-pads itself); feeding
+        # it the activated features breaks whenever in_channels != out_channels
+        m = torch.sigmoid(self.convm(x))
+        return f * m
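+
+
+# Usage sketch: the sigmoid gate modulates the activated features per pixel
+# and per channel; with an odd ksize the output keeps the input's spatial size.
+def _demo_gated_conv():
+    gconv = GatedConv(3, 16)
+    y = gconv(torch.randn(1, 3, 16, 16))
+    assert y.shape == (1, 16, 16, 16)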
+
+
+class GatedDilatedConv(nn.Module):
+    def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=1, dilation=2, act=F.elu):
+        super(GatedDilatedConv, self).__init__()
+        self.ksize = ksize
+        self.stride = stride
+        self.act = act
+        self.padSize = pad
+        self.pad = nn.ReflectionPad2d(self.padSize)
+        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize, dilation=dilation)
+        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize, dilation=dilation,
+                               padding=self.padSize)
+
+    def forward(self, x):
+        # same two-branch structure as GatedConv: features from the padded
+        # input, gate from the original input
+        f = self.act(self.convf(self.pad(x)))
+        m = torch.sigmoid(self.convm(x))
+        return f * m
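+
+
+# Usage sketch: for 'same' output with a dilated 3x3 kernel, pad must equal
+# the dilation (the defaults pad=1, dilation=2 shrink the output by 2 pixels).
+def _demo_gated_dilated_conv():
+    gdconv = GatedDilatedConv(8, 8, ksize=3, pad=2, dilation=2)
+    y = gdconv(torch.randn(1, 8, 16, 16))
+    assert y.shape == (1, 8, 16, 16)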