a b/inpainting/options/test_options.py
1
import argparse
2
import os
3
import time
4
5
class TestOptions:
    """Command-line options for running the inpainting model in test mode.

    Typical use::

        opt = TestOptions().parse()

    ``parse`` post-processes the raw argparse namespace. It converts the
    comma-separated shape strings to lists of ints and turns ``--random_mask``
    into a bool. It also derives a timestamped ``model_folder`` name and its
    ``saving_path``, and creates the output directories.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Flipped to True by initialize(); prevents registering the options twice.
        self.initialized = False

    def initialize(self):
        """Register every test-time option on ``self.parser``."""
        self.parser.add_argument('--dataset', type=str, default='paris_streetview',
                                 help='The dataset of the experiment.')
        self.parser.add_argument('--data_file', type=str, default='./imgs/paris-streetview_256x256',
                                 help='the file storing testing file paths')
        # Root directory under which parse() creates the per-run saving_path.
        self.parser.add_argument('--test_dir', type=str, default='./test_results',
                                 help='test results are saved here')
        self.parser.add_argument('--load_model_dir', type=str, default='./checkpoints',
                                 help='pretrained models are given here')
        self.parser.add_argument('--seed', type=int, default=1, help='random seed')
        self.parser.add_argument('--gpu_ids', type=str, default='0')

        self.parser.add_argument('--model', type=str, default='gmcnn')
        # Validated by argparse (choices) instead of an assert that -O would strip.
        self.parser.add_argument('--random_mask', type=int, default=0, choices=[0, 1],
                                 help='using random mask')

        # Comma-separated ints; parse() converts them to lists of int.
        self.parser.add_argument('--img_shapes', type=str, default='256,256,3',
                                 help='given shape parameters: h,w,c or h,w')
        self.parser.add_argument('--mask_shapes', type=str, default='128,128',
                                 help='given mask parameters: h,w')
        self.parser.add_argument('--mask_type', type=str, default='rect',
                                 choices=['rect', 'stroke'])
        self.parser.add_argument('--test_num', type=int, default=-1)
        # parse() only creates saving_path when mode == 'save'.
        self.parser.add_argument('--mode', type=str, default='save')
        self.parser.add_argument('--phase', type=str, default='test')

        # for generator
        self.parser.add_argument('--g_cnum', type=int, default=32,
                                 help='# of generator filters in first conv layer')
        self.parser.add_argument('--d_cnum', type=int, default=32,
                                 help='# of discriminator filters in first conv layer')

        # Without this, a second parse() call would re-register every option and
        # argparse would raise "conflicting option string".
        self.initialized = True

    def parse(self, args=None):
        """Parse the command line and derive the run-specific settings.

        Args:
            args: Optional list of argument strings. ``None`` (the default)
                lets argparse read ``sys.argv[1:]``.

        Returns:
            argparse.Namespace: the parsed options, also stored as ``self.opt``.
            Besides the raw arguments it carries:

            * ``img_shapes`` and ``mask_shapes`` as lists of int;
            * ``random_mask`` as a bool;
            * ``dataset_path``, set only when ``--data_file`` is non-empty;
            * the derived ``date_str``, ``model_folder`` and ``saving_path``.

        Side effects:
            Creates ``test_dir``, including missing parents. When
            ``mode == 'save'`` it also creates ``saving_path``. Every option is
            printed to stdout. Invalid ``--random_mask`` or ``--mask_type``
            values make argparse exit with a usage error.
        """
        if not self.initialized:
            self.initialize()
        opt = self.parser.parse_args(args)
        self.opt = opt

        if opt.data_file:
            opt.dataset_path = opt.data_file

        os.makedirs(opt.test_dir, exist_ok=True)

        # argparse already restricted the value to {0, 1}; expose it as a bool.
        opt.random_mask = opt.random_mask == 1

        opt.img_shapes = self._parse_int_list(opt.img_shapes)
        opt.mask_shapes = self._parse_int_list(opt.mask_shapes)

        # model name and date
        opt.date_str = 'test_' + time.strftime('%Y%m%d-%H%M%S')
        opt.model_folder = self._model_folder_name(opt)
        opt.saving_path = os.path.join(opt.test_dir, opt.model_folder)

        if opt.mode == 'save':
            os.makedirs(opt.saving_path, exist_ok=True)

        self._print_options(opt)
        return opt

    @staticmethod
    def _parse_int_list(text):
        """Convert a comma-separated string such as '256,256,3' into [256, 256, 3]."""
        return [int(part) for part in text.split(',')]

    @staticmethod
    def _model_folder_name(opt):
        """Build the per-run folder name from the timestamp and key options.

        Requires ``opt.img_shapes`` to hold at least two ints (height, width).
        """
        name = opt.date_str + '_' + opt.dataset + '_' + opt.model
        name += '_s' + str(opt.img_shapes[0]) + 'x' + str(opt.img_shapes[1])
        name += '_gc' + str(opt.g_cnum)
        if opt.random_mask:
            # The mask type and seed only matter when masks are random.
            name += '_randmask-' + opt.mask_type
            name += '_seed-' + str(opt.seed)
        return name

    @staticmethod
    def _print_options(opt):
        """Print every option, sorted by name, between two banner lines."""
        print('------------ Options -------------')
        for key, value in sorted(vars(opt).items()):
            print('%s: %s' % (str(key), str(value)))
        print('-------------- End ----------------')