# inpainting/test.py
import glob
import os
import subprocess
import sys

import cv2
import numpy as np

from model.net import InpaintingModel_GMCNN
from options.test_options import TestOptions
from util.utils import generate_rect_mask, generate_stroke_mask, getLatest

# Pin this process to the NVIDIA GPU with the most free memory.
# CUDA_VISIBLE_DEVICES only has an effect if it is set before CUDA is
# initialised in this process, so this must run before the model is built.


def query_free_gpu_memory():
    """Return the free memory (MiB) of each NVIDIA GPU, in nvidia-smi index order.

    Uses nvidia-smi's machine-readable ``--query-gpu`` interface rather than
    grepping the human-readable ``-q`` report, whose line layout is not a
    stable interface. Returns an empty list when nvidia-smi is missing,
    fails, or prints something that is not a plain integer per GPU.
    """
    try:
        proc = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return []
    try:
        return [int(tok) for tok in proc.stdout.split()]
    except ValueError:
        # e.g. "[N/A]" on some devices; fall back to not pinning a GPU.
        return []


def freest_gpu_index(free_mib):
    """Return the index of the largest value in *free_mib* (first on ties), or None if empty."""
    if not free_mib:
        return None
    return max(range(len(free_mib)), key=free_mib.__getitem__)


_best_gpu = freest_gpu_index(query_free_gpu_memory())
if _best_gpu is not None:
    # nvidia-smi numbers GPUs by PCI bus ID, whereas CUDA defaults to
    # "fastest first"; align the two so the chosen index means the same GPU.
    # setdefault keeps any ordering the user chose explicitly.
    os.environ.setdefault('CUDA_DEVICE_ORDER', 'PCI_BUS_ID')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(_best_gpu)

config = TestOptions().parse()

# Collect the test image paths: either a text file listing one path per line,
# or a directory whose *.png files are used.
if os.path.isfile(config.dataset_path):
    with open(config.dataset_path, 'rt') as list_file:
        # Ignore blank lines (e.g. a trailing empty line) so they are not
        # later passed to cv2.imread as paths.
        pathfile = [line for line in list_file.read().splitlines() if line.strip()]
elif os.path.isdir(config.dataset_path):
    # glob returns files in arbitrary order; sort so that image i (and hence
    # its seeded mask and its output file index) is reproducible across runs.
    pathfile = sorted(glob.glob(os.path.join(config.dataset_path, '*.png')))
else:
    print('Invalid testing data file/folder path.', file=sys.stderr)
    sys.exit(1)
total_number = len(pathfile)
# config.test_num == -1 means "use every image".
test_num = total_number if config.test_num == -1 else min(total_number, config.test_num)
print('The total number of testing images is {}, and we take {} for test.'.format(total_number, test_num))

# Build the GMCNN inpainting network (4 input channels) from the parsed
# options and print its architecture.
print('configuring model..')
ourModel = InpaintingModel_GMCNN(in_channels=4, opt=config)
ourModel.print_networks()

# When a checkpoint directory is given, load the checkpoint that getLatest
# selects among the *.pth files matched in that directory.
model_dir = config.load_model_dir
if model_dir != '':
    print('Loading pretrained model from {}'.format(model_dir))
    checkpoint_glob = os.path.join(model_dir, '*.pth')
    ourModel.load_networks(getLatest(checkpoint_glob))
    print('Loading done.')

if config.random_mask:
    # Seed NumPy's global RNG so random masks are reproducible; presumably the
    # mask generators draw from it — confirm in util.utils.
    np.random.seed(config.seed)

# cv2.imwrite does not create missing directories; it just returns False.
if config.saving_path:
    os.makedirs(config.saving_path, exist_ok=True)

img_h, img_w = config.img_shapes[0], config.img_shapes[1]
grid = 4  # spatial size is cropped to a multiple of this before evaluation

for i in range(test_num):
    # Draw the mask before reading the image so the RNG sequence (and thus the
    # mask paired with image i) does not depend on whether other images load.
    if config.mask_type == 'rect':
        mask, _ = generate_rect_mask(config.img_shapes, config.mask_shapes, config.random_mask)
    else:
        mask = generate_stroke_mask(im_size=(img_h, img_w),
                                    parts=8, maxBrushWidth=20, maxLength=100, maxVertex=20)

    image = cv2.imread(pathfile[i])
    if image is None:
        # cv2.imread reports missing/unreadable files by returning None.
        print('Skipping unreadable image: {}'.format(pathfile[i]), file=sys.stderr)
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]

    if h >= img_h and w >= img_w:
        # Large enough: take a centred crop of exactly the target size.
        h_start = (h - img_h) // 2
        w_start = (w - img_w) // 2
        image = image[h_start:h_start + img_h, w_start:w_start + img_w, :]
    else:
        # Too small in some dimension: centre-crop the largest square, then resize.
        t = min(h, w)
        image = image[(h - t) // 2:(h - t) // 2 + t, (w - t) // 2:(w - t) // 2 + t, :]
        image = cv2.resize(image, (img_w, img_h))

    # HxWxC -> CxHxW -> 1xCxHxW
    image = np.transpose(image, [2, 0, 1])
    image = np.expand_dims(image, axis=0)

    # Visualise the masked input: pixels where mask is 1 become 255.
    image_vis = image * (1 - mask) + 255 * mask
    # Back to HxWxC, reversing channels (RGB -> BGR) for cv2.imwrite.
    image_vis = np.transpose(image_vis[0][::-1, :, :], [1, 2, 0])
    cv2.imwrite(os.path.join(config.saving_path, 'input_{:03d}.png'.format(i)), image_vis.astype(np.uint8))

    # Crop H and W down to a multiple of `grid`; presumably required by the
    # network's down/up-sampling — confirm against the model.
    h, w = image.shape[2:]
    image = image[:, :, :h // grid * grid, :w // grid * grid]
    mask = mask[:, :, :h // grid * grid, :w // grid * grid]
    result = ourModel.evaluate(image, mask)
    # Back to HxWxC with channels reversed (RGB -> BGR) for cv2.imwrite.
    result = np.transpose(result[0][::-1, :, :], [1, 2, 0])
    cv2.imwrite(os.path.join(config.saving_path, '{:03d}.png'.format(i)), result)
    print(' > {} / {}'.format(i + 1, test_num))
print('done.')