import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from util.utils import gauss_kernel


class Conv2d_BN(nn.Module):
    """A Conv2d immediately followed by BatchNorm2d.

    Fix: the original passed a Python list to ``nn.Sequential``, which raises a
    TypeError — modules must be positional arguments (or an OrderedDict).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super(Conv2d_BN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      dilation, groups, bias),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, *input):
        return self.model(*input)


class upsampling(nn.Module):
    """Nearest-neighbor upsampling by an integer ``scale`` followed by a conv.

    Fix: ``align_corners`` is only valid for linear/bilinear/bicubic/trilinear
    interpolation; passing it together with ``mode='nearest'`` raises a
    ValueError, so it is dropped here.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, scale=2):
        super(upsampling, self).__init__()
        assert isinstance(scale, int)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        self.scale = scale

    def forward(self, x):
        h, w = x.size(2) * self.scale, x.size(3) * self.scale
        xout = self.conv(F.interpolate(input=x, size=(h, w), mode='nearest'))
        return xout


class PureUpsampling(nn.Module):
    """Parameter-free upsampling by an integer ``scale`` (default bilinear)."""

    def __init__(self, scale=2, mode='bilinear'):
        super(PureUpsampling, self).__init__()
        assert isinstance(scale, int)
        self.scale = scale
        self.mode = mode

    def forward(self, x):
        h, w = x.size(2) * self.scale, x.size(3) * self.scale
        if self.mode == 'nearest':
            # align_corners must not be passed for nearest interpolation.
            xout = F.interpolate(input=x, size=(h, w), mode=self.mode)
        else:
            xout = F.interpolate(input=x, size=(h, w), mode=self.mode, align_corners=True)
        return xout


class GaussianBlurLayer(nn.Module):
    """Gaussian blur (fixed kernel) with reflection padding.

    Fixes: the kernel is deterministic, so it is built once in ``__init__``
    instead of on every forward pass, and it is moved to the input's
    device/dtype at call time instead of a hard-coded ``.cuda()`` (which broke
    CPU execution). Kept as a plain attribute (not a registered buffer) so the
    module's state_dict keys are unchanged for old checkpoints.
    """

    def __init__(self, size, sigma, in_channels=1, stride=1, pad=1):
        super(GaussianBlurLayer, self).__init__()
        self.size = size
        self.sigma = sigma
        self.ch = in_channels
        self.stride = stride
        self.pad = nn.ReflectionPad2d(pad)
        # gauss_kernel is assumed to return a numpy array shaped for F.conv2d
        # (out_ch, in_ch, size, size) — TODO confirm against util.utils.
        kernel = gauss_kernel(self.size, self.sigma, self.ch, self.ch)
        self.kernel = torch.from_numpy(kernel)

    def forward(self, x):
        x = self.pad(x)
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        blurred = F.conv2d(x, kernel, stride=self.stride)
        return blurred


class ConfidenceDrivenMaskLayer(nn.Module):
    """Iteratively propagate confidence from known pixels into the hole.

    Each iteration blurs the current confidence map and re-masks it so only
    hole pixels accumulate (valid pixels are reset to confidence 1).
    """

    def __init__(self, size=65, sigma=1.0/40, iters=7):
        super(ConfidenceDrivenMaskLayer, self).__init__()
        self.size = size
        self.sigma = sigma
        self.iters = iters
        # pad=32 preserves spatial size for the 65x65 Gaussian kernel.
        self.propagationLayer = GaussianBlurLayer(size, sigma, pad=32)

    def forward(self, mask):
        # here mask 1 indicates missing pixels and 0 indicates the valid pixels
        init = 1 - mask
        mask_confidence = None
        for i in range(self.iters):
            mask_confidence = self.propagationLayer(init)
            mask_confidence = mask_confidence * mask
            init = mask_confidence + (1 - mask)
        return mask_confidence


class VGG19(nn.Module):
    """VGG-19 feature extractor returning a dict of intermediate activations.

    Keys follow the style-transfer convention: 'r<block><conv>' for post-ReLU
    activations and 'p<block>' for pooled outputs.

    Fix: an unrecognized ``pool`` argument used to leave pool1..pool5
    undefined, surfacing later as a confusing AttributeError; it now raises
    immediately.
    """

    def __init__(self, pool='max'):
        super(VGG19, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        if pool == 'max':
            pool_cls = nn.MaxPool2d
        elif pool == 'avg':
            pool_cls = nn.AvgPool2d
        else:
            raise ValueError('pool must be "max" or "avg", got %r' % (pool,))
        self.pool1 = pool_cls(kernel_size=2, stride=2)
        self.pool2 = pool_cls(kernel_size=2, stride=2)
        self.pool3 = pool_cls(kernel_size=2, stride=2)
        self.pool4 = pool_cls(kernel_size=2, stride=2)
        self.pool5 = pool_cls(kernel_size=2, stride=2)

    def forward(self, x):
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return out


class VGG19FeatLayer(nn.Module):
    """Pretrained torchvision VGG-19 wrapper returning named activations.

    Inputs are expected in [0, 1] RGB; only the ImageNet mean is subtracted
    (no std division), matching the original implementation.

    NOTE(review): the pretrained features and the mean are pinned to CUDA at
    construction, exactly as in the original — callers on CPU-only machines
    will fail here; confirm before relaxing.
    """

    def __init__(self):
        super(VGG19FeatLayer, self).__init__()
        self.vgg19 = models.vgg19(pretrained=True).features.eval().cuda()
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()

    def forward(self, x):
        out = {}
        x = x - self.mean
        ci = 1  # conv-block index (bumped at each pooling layer)
        ri = 0  # layer index within the current block
        for layer in self.vgg19.children():
            if isinstance(layer, nn.Conv2d):
                ri += 1
                name = 'conv{}_{}'.format(ci, ri)
            elif isinstance(layer, nn.ReLU):
                ri += 1
                name = 'relu{}_{}'.format(ci, ri)
                # Use a non-inplace ReLU so the pre-activation conv outputs
                # stored in `out` are not overwritten.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                ri = 0
                name = 'pool_{}'.format(ci)
                ci += 1
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(ci)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
            x = layer(x)
            out[name] = x
        return out


def init_weights(net, init_type='normal', gain=0.02):
    """Initialize Conv/Linear weights of ``net`` in place.

    init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal'.
    gain: std for 'normal', gain for 'xavier'/'orthogonal', and the std used
    for BatchNorm weight initialization.
    Raises NotImplementedError for an unknown ``init_type``.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, gain)
            nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', gpu_ids=None):
    """Move ``net`` to the first GPU id (wrapping in DataParallel) and init weights.

    ``gpu_ids`` defaults to no GPUs; the mutable-default list was replaced by a
    None sentinel (backward compatible: explicit lists behave identically).
    """
    gpu_ids = gpu_ids if gpu_ids is not None else []
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net


def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm (eps guards against zero vectors)."""
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Spectral normalization wrapper (Miyato et al.) around ``module``.

    Estimates the weight's largest singular value by power iteration on
    persistent ``_u``/``_v`` vectors and divides the stored raw weight
    (``<name>_bar``) by it before every forward.
    """

    def __init__(self, module, name='weight', power_iteration=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iteration = power_iteration
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + '_u')
        v = getattr(self.module, self.name + '_v')
        w = getattr(self.module, self.name + '_bar')

        height = w.data.shape[0]
        for _ in range(self.power_iteration):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma approximates the spectral norm (largest singular value) of w.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        # EAFP probe: the three auxiliary params exist only after _make_params.
        try:
            getattr(self.module, self.name + '_u')
            getattr(self.module, self.name + '_v')
            getattr(self.module, self.name + '_bar')
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Remove the raw weight so the normalized tensor set in _update_u_v
        # becomes the attribute the wrapped module actually uses.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + '_u', u)
        self.module.register_parameter(self.name + '_v', v)
        self.module.register_parameter(self.name + '_bar', w_bar)

    def forward(self, *input):
        self._update_u_v()
        return self.module.forward(*input)


class PartialConv(nn.Module):
    """Partial convolution (Liu et al.): renormalize by local mask coverage.

    Fix: the all-ones summation kernel was rebuilt from numpy and force-moved
    to CUDA on every call; it is now created directly on the mask's device and
    dtype, which also restores CPU support.
    """

    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1):
        super(PartialConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.padSize = self.ksize // 2
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.eplison = 1e-5  # avoids division by zero where the mask is empty
        self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)

    def forward(self, x, mask):
        mask_ch = mask.size(1)
        sum_kernel = torch.ones(mask_ch, mask_ch, self.ksize, self.ksize,
                                device=mask.device, dtype=mask.dtype)

        # Scale valid pixels by the inverse of local mask coverage.
        x = x * mask / (F.conv2d(mask, sum_kernel, stride=1, padding=self.padSize) + self.eplison)
        x = self.pad(x)
        x = self.conv(x)
        # A window survives if it contains at least one valid pixel.
        mask = F.max_pool2d(mask, self.ksize, stride=self.stride, padding=self.padSize)
        return x, mask


class GatedConv(nn.Module):
    """Gated convolution: feature branch modulated by a learned sigmoid gate.

    Fix: the gate conv (``convm``, declared with ``in_channels`` inputs and its
    own zero padding) was applied to the *activated feature map*, which has
    ``out_channels`` channels — that crashed whenever in_channels !=
    out_channels (including the 3->32 defaults). The gate is now computed from
    the input, and both branches produce matching spatial sizes.
    """

    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1, act=F.elu):
        super(GatedConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.act = act
        self.padSize = self.ksize // 2
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)
        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize,
                               padding=self.padSize)

    def forward(self, x):
        feat = self.act(self.convf(self.pad(x)))
        gate = torch.sigmoid(self.convm(x))  # torch.sigmoid: F.sigmoid is deprecated
        return feat * gate


class GatedDilatedConv(nn.Module):
    """Dilated gated convolution; see GatedConv for the gating scheme.

    Fix: same channel-mismatch bug as GatedConv — the gate is now computed
    from the input, so both branches have compatible channels and sizes.
    """

    def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=1, dilation=2, act=F.elu):
        super(GatedDilatedConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.act = act
        self.padSize = pad
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize,
                               dilation=dilation)
        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize,
                               dilation=dilation, padding=self.padSize)

    def forward(self, x):
        feat = self.act(self.convf(self.pad(x)))
        gate = torch.sigmoid(self.convm(x))
        return feat * gate