# inpainting/test.py -- run a trained GMCNN inpainting model over a set of test images.
import glob
import os
import subprocess
import sys

import cv2
import numpy as np

from options.test_options import TestOptions
from model.net import InpaintingModel_GMCNN
from util.utils import generate_rect_mask, generate_stroke_mask, getLatest

def _parse_free_mib(text):
    """Parse ``nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits``.

    Returns one integer (free MiB) per non-blank line, in GPU index order.
    Non-numeric entries such as ``[N/A]`` become -1 so that list positions
    still line up with GPU indices.
    """
    free = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            free.append(int(line))
        except ValueError:
            free.append(-1)
    return free


def _select_freest_gpu():
    """Expose only the GPU with the most free memory via CUDA_VISIBLE_DEVICES.

    Leaves the environment untouched when nvidia-smi is missing, exits with an
    error, or reports no GPUs.
    """
    try:
        # List form, no shell; run() waits for the process and closes the pipe.
        proc = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
            universal_newlines=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return
    free = _parse_free_mib(proc.stdout)
    if not free:
        return
    # nvidia-smi numbers GPUs in PCI bus order, while CUDA defaults to
    # fastest-first; align them so the chosen index is the same device.
    os.environ.setdefault('CUDA_DEVICE_ORDER', 'PCI_BUS_ID')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(int(np.argmax(free)))


_select_freest_gpu()

config = TestOptions().parse()

# Collect the test image paths: one path per line from a text file, or every
# *.png directly inside a directory.
if os.path.isfile(config.dataset_path):
    with open(config.dataset_path, 'rt') as f:
        # Ignore blank lines so they are never handed to cv2.imread as paths.
        pathfile = [line for line in f.read().splitlines() if line.strip()]
elif os.path.isdir(config.dataset_path):
    # glob's order is filesystem-dependent; sort it so output numbering and
    # the subset selected by test_num are reproducible across machines.
    pathfile = sorted(glob.glob(os.path.join(config.dataset_path, '*.png')))
else:
    print('Invalid testing data file/folder path.', file=sys.stderr)
    sys.exit(1)
total_number = len(pathfile)
# test_num == -1 means "use every image"; otherwise cap it at what exists.
test_num = total_number if config.test_num == -1 else min(total_number, config.test_num)
print('The total number of testing images is {}, and we take {} for test.'.format(total_number, test_num))

# Build the GMCNN inpainting network from the parsed options and print its
# architecture. in_channels=4 is presumably RGB + mask -- confirm in model.net.
print('configuring model..')
ourModel = InpaintingModel_GMCNN(in_channels=4, opt=config)
ourModel.print_networks()

# Optionally restore weights from the *.pth file under load_model_dir that
# getLatest picks (presumably the newest one -- see util.utils).
checkpoint_dir = config.load_model_dir
if checkpoint_dir != '':
    print('Loading pretrained model from {}'.format(checkpoint_dir))
    checkpoint_path = getLatest(os.path.join(checkpoint_dir, '*.pth'))
    ourModel.load_networks(checkpoint_path)
    print('Loading done.')

def _fit_to_shape(image, out_h, out_w):
    """Return an out_h x out_w crop of the H x W x C array *image*.

    If the image is at least out_h x out_w it is centre-cropped to that size.
    Otherwise the largest centred square is cut out and resized to
    out_h x out_w (which changes the aspect ratio when out_h != out_w).
    """
    h, w = image.shape[:2]
    if h >= out_h and w >= out_w:
        top, left = (h - out_h) // 2, (w - out_w) // 2
        return image[top:top + out_h, left:left + out_w, :]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    square = image[top:top + side, left:left + side, :]
    # cv2.resize takes the destination size as (width, height).
    return cv2.resize(square, (out_w, out_h))


def _write_image(path, image):
    """Save *image* with cv2.imwrite, raising OSError if OpenCV reports failure.

    cv2.imwrite reports failures (e.g. a missing target directory) only
    through its boolean return value, so an unchecked call fails silently.
    """
    if not cv2.imwrite(path, image):
        raise OSError('Could not write image to {}'.format(path))


if config.random_mask:
    # Seed NumPy's global RNG so random masks are reproducible for a given
    # config.seed (assuming the util.utils mask generators use np.random).
    np.random.seed(config.seed)

# Create the output directory up front so the writes below cannot all fail.
# An empty saving_path means "current directory" and needs no creation.
if config.saving_path:
    os.makedirs(config.saving_path, exist_ok=True)

img_h, img_w = config.img_shapes[0], config.img_shapes[1]
grid = 4  # the spatial size fed to the model is cropped to a multiple of this

for i in range(test_num):
    if config.mask_type == 'rect':
        mask, _ = generate_rect_mask(config.img_shapes, config.mask_shapes, config.random_mask)
    else:
        mask = generate_stroke_mask(im_size=(img_h, img_w),
                                    parts=8, maxBrushWidth=20, maxLength=100, maxVertex=20)
    path = pathfile[i]
    image = cv2.imread(path)
    if image is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise OSError('Could not read image {}'.format(path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = _fit_to_shape(image, img_h, img_w)

    # H x W x C -> 1 x C x H x W.
    image = np.expand_dims(np.transpose(image, [2, 0, 1]), axis=0)

    # Input visualisation: pixels where the mask is 1 are painted white (255).
    image_vis = image * (1 - mask) + 255 * mask
    # Drop the batch axis, reverse channels (RGB -> BGR for OpenCV), back to H x W x C.
    image_vis = np.transpose(image_vis[0][::-1, :, :], [1, 2, 0])
    _write_image(os.path.join(config.saving_path, 'input_{:03d}.png'.format(i)),
                 image_vis.astype(np.uint8))

    h, w = image.shape[2:]
    image = image[:, :, :h // grid * grid, :w // grid * grid]
    mask = mask[:, :, :h // grid * grid, :w // grid * grid]
    result = ourModel.evaluate(image, mask)
    # Same channel flip / layout change as for image_vis. NOTE(review): assumes
    # evaluate returns a 1 x C x H x W RGB array in a dtype imwrite accepts
    # (it is not cast to uint8 here) -- confirm in model.net.
    result = np.transpose(result[0][::-1, :, :], [1, 2, 0])
    _write_image(os.path.join(config.saving_path, '{:03d}.png'.format(i)), result)
    print(' > {} / {}'.format(i + 1, test_num))
print('done.')