# params/train_test_params.py
from .basic_params import BasicParams

class TrainTestParams(BasicParams):
    """Command-line parameters for training and testing.

    Subclass of ``BasicParams``: it registers training/testing-specific
    options on top of the options inherited from the parent class.
    """

    def initialize(self, parser):
        """Register train/test options on *parser* and return it.

        The parent class adds its own options first; this method then
        appends the training-schedule, optimizer, checkpointing, logging
        and dataset-split options, and records that this parameter set is
        used for both training and testing.
        """
        parser = BasicParams.initialize(self, parser)
        add = parser.add_argument  # shorthand; all options go on the same parser

        # -- training schedule & optimizer ------------------------------
        add('--epoch_num_p1', type=int, default=50,
            help='epoch number for phase 1')
        add('--epoch_num_p2', type=int, default=50,
            help='epoch number for phase 2')
        add('--epoch_num_p3', type=int, default=100,
            help='epoch number for phase 3')
        add('--lr', type=float, default=1e-4,
            help='initial learning rate')
        add('--beta1', type=float, default=0.5,
            help='momentum term of adam')
        add('--lr_policy', type=str, default='linear',
            help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]')
        add('--epoch_count', type=int, default=1,
            help='the starting epoch count, default start from 1')
        add('--epoch_num_decay', type=int, default=50,
            help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)')
        add('--decay_step_size', type=int, default=50,
            help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)')
        add('--weight_decay', type=float, default=1e-4,
            help='weight decay (L2 penalty)')

        # -- checkpoint saving & resuming --------------------------------
        add('--continue_train', action='store_true',
            help='load the latest model and continue training')
        add('--save_model', action='store_true',
            help='save the model during training')
        add('--save_epoch_freq', type=int, default=-1,
            help='frequency of saving checkpoints at the end of epochs, -1 means only save the last epoch')

        # -- console logging ---------------------------------------------
        add('--print_freq', type=int, default=1,
            help='frequency of showing results on console')

        # -- dataset split -----------------------------------------------
        add('--train_ratio', type=float, default=0.8,
            help='ratio of training set in the full dataset')
        add('--test_ratio', type=float, default=0.2,
            help='ratio of testing set in the full dataset')

        # This parameter set covers both phases.
        self.isTrain = True
        self.isTest = True
        return parser