Diff of /inpainting/model/net.py [000000] .. [92cc18]

import torch
import torch.nn as nn
import torch.nn.functional as F
from inpainting.model.basemodel import BaseModel
from inpainting.model.basenet import BaseNet
from inpainting.model.loss import WGANLoss, IDMRFLoss
from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
import numpy as np


# generative multi-column convolutional neural net
class GMCNN(BaseNet):
    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
        super(GMCNN, self).__init__()
        self.act = act
        self.using_norm = using_norm
        if using_norm:
            self.norm = norm
        else:
            self.norm = None
        ch = cnum

        # network structure: three encoder-decoder branches with different kernel sizes
        self.EB1 = []
        self.EB2 = []
        self.EB3 = []
        self.decoding_layers = []

        self.EB1_pad_rec = []
        self.EB2_pad_rec = []
        self.EB3_pad_rec = []

        # branch 1: 7x7 kernels
        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(PureUpsampling(scale=4))

        # per-layer reflection-padding sizes; 0 marks an upsampling layer
        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]

        # branch 2: 5x5 kernels
        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(PureUpsampling(scale=2))
        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]

        # branch 3: 3x3 kernels
        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))

        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]

        # shared decoder over the concatenated branch outputs (ch * 4 + ch * 2 + ch = ch * 7)
        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))

        self.decoding_pad_rec = [1, 1]

        self.EB1 = nn.ModuleList(self.EB1)
        self.EB2 = nn.ModuleList(self.EB2)
        self.EB3 = nn.ModuleList(self.EB3)
        self.decoding_layers = nn.ModuleList(self.decoding_layers)

        # padding operations: one ReflectionPad2d module per padding size used above
        padlen = 49
        self.pads = [0] * padlen
        for i in range(padlen):
            self.pads[i] = nn.ReflectionPad2d(i)
        self.pads = nn.ModuleList(self.pads)

    def forward(self, x):
        x1, x2, x3 = x, x, x
        for i, layer in enumerate(self.EB1):
            pad_idx = self.EB1_pad_rec[i]
            x1 = layer(self.pads[pad_idx](x1))
            if self.using_norm:
                x1 = self.norm(x1)
            if pad_idx != 0:  # pad 0 marks an upsampling layer: no activation after it
                x1 = self.act(x1)

        for i, layer in enumerate(self.EB2):
            pad_idx = self.EB2_pad_rec[i]
            x2 = layer(self.pads[pad_idx](x2))
            if self.using_norm:
                x2 = self.norm(x2)
            if pad_idx != 0:
                x2 = self.act(x2)

        for i, layer in enumerate(self.EB3):
            pad_idx = self.EB3_pad_rec[i]
            x3 = layer(self.pads[pad_idx](x3))
            if self.using_norm:
                x3 = self.norm(x3)
            if pad_idx != 0:
                x3 = self.act(x3)

        x_d = torch.cat((x1, x2, x3), 1)
        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
        x_out = torch.clamp(x_d, -1, 1)
        return x_out
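
# Usage sketch (illustrative, not in the original file): with in_channels=4 the generator
# takes a masked image concatenated with its binary mask along the channel axis:
#
#   net = GMCNN(in_channels=4, out_channels=3, cnum=32)
#   gin = torch.cat((img * (1 - mask), mask), 1)   # (B, 4, H, W), img scaled to [-1, 1]
#   out = net(gin)                                 # (B, 3, H, W), clamped to [-1, 1]
#
# H and W should be multiples of 4 so the two stride-2 stages and the upsamplers line up.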


# returns a one-dimensional logit indicating how real or fake the input looks
class Discriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, fc_channels=8 * 8 * 32 * 4, act=F.elu, norm=None, spectral_norm=True):
        super(Discriminator, self).__init__()
        self.act = act
        self.norm = norm
        self.embedding = None
        self.logit = None

        ch = cnum
        self.layers = []
        if spectral_norm:
            self.layers.append(SpectralNorm(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Linear(fc_channels, 1)))
        else:
            self.layers.append(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Linear(fc_channels, 1))
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.act(x)
        self.embedding = x.view(x.size(0), -1)
        self.logit = self.layers[-1](self.embedding)
        return self.logit
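
# Shape note (editorial): the four stride-2 convolutions shrink the input by 16x, so for a
# 128x128 input the flattened embedding has 8 * 8 * cnum * 4 features, matching the
# fc_channels default above with cnum=32. The returned value is an unbounded logit,
# trained with a hinge loss in forward_D below.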


class GlobalLocalDiscriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, g_fc_channels=16 * 16 * 32 * 4, l_fc_channels=8 * 8 * 32 * 4,
                 act=F.elu, norm=None, spectral_norm=True):
        super(GlobalLocalDiscriminator, self).__init__()
        self.act = act
        self.norm = norm

        self.global_discriminator = Discriminator(in_channels=in_channels, fc_channels=g_fc_channels, cnum=cnum,
                                                  act=act, norm=norm, spectral_norm=spectral_norm)
        self.local_discriminator = Discriminator(in_channels=in_channels, fc_channels=l_fc_channels, cnum=cnum,
                                                 act=act, norm=norm, spectral_norm=spectral_norm)

    def forward(self, x_g, x_l):
        x_global = self.global_discriminator(x_g)
        x_local = self.local_discriminator(x_l)
        return x_global, x_local
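
# Note (editorial): g_fc_channels / l_fc_channels must equal (H // 16) * (W // 16) * cnum * 4
# for the respective input resolution; the defaults correspond to 256x256 global and
# 128x128 local inputs. In this variant the local branch is fed the full completed image
# (see optimize_parameters below), so InpaintingModel_GMCNN computes both sizes from
# opt.img_shapes.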


from inpainting.util.utils import generate_mask  # only referenced by the commented-out mask generation below


class InpaintingModel_GMCNN(BaseModel):
    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
        super(InpaintingModel_GMCNN, self).__init__()
        self.opt = opt
        self.init(opt)

        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()

        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda()
        init_weights(self.netGM)
        self.model_names = ['GM']
        if self.opt.phase == 'test':
            return

        self.netD = None

        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.mrfloss = None
        self.lambda_adv = opt.lambda_adv
        self.lambda_rec = opt.lambda_rec
        self.lambda_ae = opt.lambda_ae
        self.lambda_gp = opt.lambda_gp
        self.lambda_mrf = opt.lambda_mrf
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None

        if self.opt.pretrain_network is False:
            self.netD = GlobalLocalDiscriminator(
                3, cnum=opt.d_cnum, act=act, spectral_norm=self.opt.spectral_norm,
                g_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[1] // 16 * opt.d_cnum * 4,
                l_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[1] // 16 * opt.d_cnum * 4).cuda()
            init_weights(self.netD)
            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()),
                                                lr=opt.lr, betas=(0.5, 0.9))
            self.wganloss = WGANLoss()
            self.mrfloss = IDMRFLoss()

    def initVariables(self):
        # self.img, self.mask and self.mask_01 (the binary mask) are expected to be set
        # by the caller before optimize_parameters() is invoked
        self.gt = self.img
        mask = self.mask
        # mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)
        # self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])
        self.mask = self.confidence_mask_layer(mask)
        self.gt_local = self.gt
        self.im_in = self.gt * (1 - self.mask_01)
        self.gin = torch.cat((self.im_in, self.mask_01), 1)

    def forward_G(self):
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae

        if self.opt.pretrain_network is False:
            # adversarial and ID-MRF terms only apply once the discriminator exists
            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
            self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf

            self.G_loss_adv = -self.completed_logit.mean()
            self.G_loss_adv_local = -self.completed_local_logit.mean()
            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)

    def forward_D(self):
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
                                                                     self.completed_local.detach())
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss: relu(1 - D(real)) + relu(1 + D(fake)), for both branches
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(
            1.0 + self.completed_local_logit).mean()
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
        self.D_loss = self.D_loss + self.D_loss_local

    def backward_G(self):
        self.G_loss.backward()

    def backward_D(self):
        self.D_loss.backward(retain_graph=True)

    def optimize_parameters(self):
        self.initVariables()

        self.pred = self.netGM(self.gin)
        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
        self.completed_local = self.completed

        if self.opt.pretrain_network is False:
            # the discriminator is only created and trained outside the pretraining phase
            for i in range(self.opt.D_max_iters):
                self.optimizer_D.zero_grad()
                self.optimizer_G.zero_grad()
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.forward_G()
        self.backward_G()
        self.optimizer_G.step()
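
    # Training-loop sketch (illustrative, not in the original file). initVariables reads
    # self.img plus the binary mask tensors, which are assumed to be set by the caller
    # before each step, e.g.:
    #
    #   model = InpaintingModel_GMCNN(in_channels=4, opt=opt)
    #   for img, mask01 in loader:                    # img in [-1, 1], mask01 in {0, 1}
    #       model.img = img.cuda()
    #       model.mask, model.mask_01 = mask01.cuda(), mask01.cuda()
    #       model.optimize_parameters()
    #       print(model.get_current_losses())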

    def get_current_losses(self):
        l = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
             'G_loss_ae': self.G_loss_ae.item()}
        if self.opt.pretrain_network is False:
            l.update({'G_loss_adv': self.G_loss_adv.item(),
                      'G_loss_adv_local': self.G_loss_adv_local.item(),
                      'D_loss': self.D_loss.item(),
                      'G_loss_mrf': self.G_loss_mrf.item()})
        return l

    def get_current_visuals(self):
        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
                'completed': self.completed.cpu().detach().numpy()}

    def get_current_visuals_tensor(self):
        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
                'completed': self.completed.cpu().detach()}

    def evaluate(self, im_in, mask):
        # map uint8 input from [0, 255] to [-1, 1], inpaint the masked region, and map back
        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
        im_in = im_in * (1 - mask)
        xin = torch.cat((im_in, mask), 1)
        ret = self.netGM(xin) * mask + im_in * (1 - mask)
        ret = (ret.cpu().detach().numpy() + 1) * 127.5
        return ret.astype(np.uint8)
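

if __name__ == '__main__':
    # Minimal CPU smoke test: an editorial sketch, not part of the original file.
    # It assumes PureUpsampling and SpectralNorm from inpainting.model.layer behave as
    # standard upsampling / spectral-norm wrappers and only checks tensor shapes.
    g = GMCNN(in_channels=4, out_channels=3, cnum=32)
    x = torch.randn(1, 4, 256, 256)  # masked RGB image + binary mask, stacked on channels
    with torch.no_grad():
        y = g(x)
    print(y.shape)  # expected: torch.Size([1, 3, 256, 256])

    d = GlobalLocalDiscriminator(3, cnum=32,
                                 g_fc_channels=16 * 16 * 32 * 4,  # (256 // 16)^2 * cnum * 4
                                 l_fc_channels=16 * 16 * 32 * 4)
    with torch.no_grad():
        logit_g, logit_l = d(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
    print(logit_g.shape, logit_l.shape)  # expected: torch.Size([1, 1]) each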