# inpainting/model/net.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from inpainting.model.basemodel import BaseModel
from inpainting.model.basenet import BaseNet
from inpainting.model.loss import WGANLoss, IDMRFLoss
from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
from inpainting.util.utils import generate_mask


# generative multi-column convolutional neural net
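# (three parallel encoder branches with 7x7 / 5x5 / 3x3 convolutions; their
# upsampled feature maps are concatenated and fused by a shared decoder, and
# the output is clamped to [-1, 1])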
class GMCNN(BaseNet):
    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
        super(GMCNN, self).__init__()
        self.act = act
        self.using_norm = using_norm
        if using_norm:
            self.norm = norm
        else:
            self.norm = None
        ch = cnum

        # network structure
        self.EB1 = []
        self.EB2 = []
        self.EB3 = []
        self.decoding_layers = []

        self.EB1_pad_rec = []
        self.EB2_pad_rec = []
        self.EB3_pad_rec = []

        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(PureUpsampling(scale=4))

        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]
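        # pad_rec lists record the reflection padding applied before the layer
        # at the same index: kernel_size // 2 * dilation for convolutions, and
        # 0 for upsampling layers (same convention for EB2/EB3 below)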

        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(PureUpsampling(scale=2))
        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]

        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))

        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]

        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))

        self.decoding_pad_rec = [1, 1]

        self.EB1 = nn.ModuleList(self.EB1)
        self.EB2 = nn.ModuleList(self.EB2)
        self.EB3 = nn.ModuleList(self.EB3)
        self.decoding_layers = nn.ModuleList(self.decoding_layers)

        # bank of reflection-padding layers, indexed by pad size: pads[i] pads by i
        padlen = 49
        self.pads = [0] * padlen
        for i in range(padlen):
            self.pads[i] = nn.ReflectionPad2d(i)
        self.pads = nn.ModuleList(self.pads)

    def forward(self, x):
        x1, x2, x3 = x, x, x
        for i, layer in enumerate(self.EB1):
            pad_idx = self.EB1_pad_rec[i]
            x1 = layer(self.pads[pad_idx](x1))
            if self.using_norm:
                x1 = self.norm(x1)
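            # upsampling layers are recorded with pad 0, so they skip the activation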
            if pad_idx != 0:
                x1 = self.act(x1)

        for i, layer in enumerate(self.EB2):
            pad_idx = self.EB2_pad_rec[i]
            x2 = layer(self.pads[pad_idx](x2))
            if self.using_norm:
                x2 = self.norm(x2)
            if pad_idx != 0:
                x2 = self.act(x2)

        for i, layer in enumerate(self.EB3):
            pad_idx = self.EB3_pad_rec[i]
            x3 = layer(self.pads[pad_idx](x3))
            if self.using_norm:
                x3 = self.norm(x3)
            if pad_idx != 0:
                x3 = self.act(x3)

        x_d = torch.cat((x1, x2, x3), 1)
        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
        x_out = torch.clamp(x_d, -1, 1)
        return x_out


# returns a single logit per image indicating how real or fake the input looks
class Discriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, fc_channels=8 * 8 * 32 * 4, act=F.elu, norm=None, spectral_norm=True):
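        # NOTE: the default fc_channels assumes a 128x128 input, which the four
        # stride-2 convolutions below reduce to an 8x8 map with cnum * 4 channels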
        super(Discriminator, self).__init__()
        self.act = act
        self.norm = norm
        self.embedding = None
        self.logit = None

        ch = cnum
        self.layers = []
        if spectral_norm:
            self.layers.append(SpectralNorm(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Linear(fc_channels, 1)))
        else:
            self.layers.append(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Linear(fc_channels, 1))
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.act(x)
        self.embedding = x.view(x.size(0), -1)
        self.logit = self.layers[-1](self.embedding)
        return self.logit


class GlobalLocalDiscriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, g_fc_channels=16 * 16 * 32 * 4, l_fc_channels=8 * 8 * 32 * 4, act=F.elu,
                 norm=None, spectral_norm=True):
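        # NOTE: the default fc sizes assume a 256x256 global view (16x16 after
        # four stride-2 convolutions) and a 128x128 local view (8x8)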
        super(GlobalLocalDiscriminator, self).__init__()
        self.act = act
        self.norm = norm

        self.global_discriminator = Discriminator(in_channels=in_channels, fc_channels=g_fc_channels, cnum=cnum,
                                                  act=act, norm=norm, spectral_norm=spectral_norm)
        self.local_discriminator = Discriminator(in_channels=in_channels, fc_channels=l_fc_channels, cnum=cnum,
                                                 act=act, norm=norm, spectral_norm=spectral_norm)

    def forward(self, x_g, x_l):
        x_global = self.global_discriminator(x_g)
        x_local = self.local_discriminator(x_l)
        return x_global, x_local


class InpaintingModel_GMCNN(BaseModel):
    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
        super(InpaintingModel_GMCNN, self).__init__()
        self.opt = opt
        self.init(opt)

        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
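        # (presumably spreads the binary mask into smoothly decaying confidence
        # weights for the spatially weighted reconstruction loss, as in the
        # GMCNN paper's confidence-driven loss)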

        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda()
        init_weights(self.netGM)
        self.model_names = ['GM']
        if self.opt.phase == 'test':
            return

        self.netD = None

        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.mrfloss = None
        self.lambda_adv = opt.lambda_adv
        self.lambda_rec = opt.lambda_rec
        self.lambda_ae = opt.lambda_ae
        self.lambda_gp = opt.lambda_gp
        self.lambda_mrf = opt.lambda_mrf
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None

        # the discriminator and adversarial/MRF losses are only needed past the
        # reconstruction-only pretraining stage
        if self.opt.pretrain_network is False:
            self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                 spectral_norm=self.opt.spectral_norm,
                                                 g_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[1] // 16 * opt.d_cnum * 4,
                                                 l_fc_channels=opt.img_shapes[0] // 16 * opt.img_shapes[1] // 16 * opt.d_cnum * 4).cuda()
            init_weights(self.netD)
            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()),
                                                lr=opt.lr, betas=(0.5, 0.9))
            self.wganloss = WGANLoss()
            self.mrfloss = IDMRFLoss()

    def initVariables(self):
        self.gt = self.img
        mask = self.mask
        # mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)
        # self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])
        self.mask_01 = mask  # keep the binary mask; self.mask becomes the confidence-weighted version
        self.mask = self.confidence_mask_layer(mask)
        self.gt_local = self.gt
        self.im_in = self.gt * (1 - self.mask_01)
        self.gin = torch.cat((self.im_in, self.mask_01), 1)

    def forward_G(self):
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae

        # the adversarial and MRF terms only apply once a discriminator exists
        if self.opt.pretrain_network is False:
            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
            self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf

            self.G_loss_adv = -self.completed_logit.mean()
            self.G_loss_adv_local = -self.completed_local_logit.mean()
            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)

    def forward_D(self):
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
                                                                     self.completed_local.detach())
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss
        self.D_loss_local = F.relu(1.0 - self.gt_local_logit).mean() + F.relu(1.0 + self.completed_local_logit).mean()
        self.D_loss = F.relu(1.0 - self.gt_logit).mean() + F.relu(1.0 + self.completed_logit).mean()
        self.D_loss = self.D_loss + self.D_loss_local

    def backward_G(self):
        self.G_loss.backward()

    def backward_D(self):
        self.D_loss.backward(retain_graph=True)

    def optimize_parameters(self):
        self.initVariables()

        self.pred = self.netGM(self.gin)
        # paste the prediction into the hole, keep the ground truth elsewhere
        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
        self.completed_local = self.completed

        # discriminator updates only happen past the pretraining stage
        if self.opt.pretrain_network is False:
            for i in range(self.opt.D_max_iters):
                self.optimizer_D.zero_grad()
                self.optimizer_G.zero_grad()
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.forward_G()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_losses(self):
        losses = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
                  'G_loss_ae': self.G_loss_ae.item()}
        if self.opt.pretrain_network is False:
            losses.update({'G_loss_adv': self.G_loss_adv.item(),
                           'G_loss_adv_local': self.G_loss_adv_local.item(),
                           'D_loss': self.D_loss.item(),
                           'G_loss_mrf': self.G_loss_mrf.item()})
        return losses

    def get_current_visuals(self):
        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
                'completed': self.completed.cpu().detach().numpy()}

    def get_current_visuals_tensor(self):
        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
                'completed': self.completed.cpu().detach()}

    def evaluate(self, im_in, mask):
        # im_in: uint8 image array, rescaled here to [-1, 1]; mask: 1 inside the hole
        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
        im_in = im_in * (1 - mask)
        xin = torch.cat((im_in, mask), 1)
        ret = self.netGM(xin) * mask + im_in * (1 - mask)
        ret = (ret.cpu().detach().numpy() + 1) * 127.5
        return ret.astype(np.uint8)
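

# minimal smoke test (illustrative only, not part of the original training code):
# builds the generator on CPU and checks output shape and range; it assumes a
# 4-channel input (RGB image + mask) at 256x256, which is large enough for the
# 48-pixel reflection padding EB1 applies at 1/4 scale
if __name__ == '__main__':
    net = GMCNN(in_channels=4, out_channels=3, cnum=32)
    dummy = torch.randn(1, 4, 256, 256)
    with torch.no_grad():
        out = net(dummy)
    assert out.shape == (1, 3, 256, 256)
    assert out.min() >= -1 and out.max() <= 1  # output is clamped to [-1, 1]
    print('GMCNN forward OK:', tuple(out.shape))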