[03464c]: / params / train_params.py

Download this file

51 lines (44 with data), 2.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from .basic_params import BasicParams
class TrainParams(BasicParams):
    """Command-line parameters used for training.

    Subclass of ``BasicParams``: it registers every option that the parent's
    ``initialize`` adds, then the training-specific options below, and marks
    this parameter set as a training (not test) configuration.
    """

    def initialize(self, parser):
        """Register the basic and training options on ``parser``.

        Args:
            parser: the argument parser to extend. It is first handed to the
                parent's ``initialize``; the training options are then added
                to the parser that call returns. (Presumably an
                ``argparse.ArgumentParser`` -- only ``add_argument`` is used
                here; confirm against ``BasicParams``.)

        Returns:
            The parser with both the basic and the training options registered.

        Side effects:
            Sets ``self.isTrain = True`` and ``self.isTest = False`` after the
            options have been added.
        """
        # Cooperative parent call. For TrainParams itself this resolves to
        # BasicParams.initialize, exactly like the explicit
        # BasicParams.initialize(self, parser) form, but it also stays correct
        # if TrainParams is later combined with other classes via multiple
        # inheritance.
        parser = super().initialize(parser)

        # Training parameters
        parser.add_argument('--epoch_num_p1', type=int, default=50,
                            help='epoch number for phase 1')
        parser.add_argument('--epoch_num_p2', type=int, default=50,
                            help='epoch number for phase 2')
        parser.add_argument('--epoch_num_p3', type=int, default=100,
                            help='epoch number for phase 3')
        parser.add_argument('--lr', type=float, default=1e-4,
                            help='initial learning rate')
        parser.add_argument('--beta1', type=float, default=0.5,
                            help='momentum term of adam')
        # The help text lists the supported policies, but no `choices`
        # constraint is enforced at parse time.
        # NOTE(review): validation presumably happens where the scheduler is
        # built -- confirm before adding `choices=` here.
        parser.add_argument('--lr_policy', type=str, default='linear',
                            help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]')
        parser.add_argument('--epoch_count', type=int, default=1,
                            help='the starting epoch count, default start from 1')
        # Only meaningful when --lr_policy is 'linear' (per its help text).
        parser.add_argument('--epoch_num_decay', type=int, default=50,
                            help='Number of epochs to linearly decay learning rate to zero (lr_policy == linear)')
        # Only meaningful when --lr_policy is 'step' (per its help text).
        # NOTE(review): gamma is not defined in this class; presumably it is
        # supplied by BasicParams or hard-coded in the scheduler -- confirm.
        parser.add_argument('--decay_step_size', type=int, default=50,
                            help='The learning rate is multiplied by gamma every decay_step_size epochs (lr_policy == step)')
        parser.add_argument('--weight_decay', type=float, default=1e-4,
                            help='weight decay (L2 penalty)')

        # Network saving and loading parameters
        parser.add_argument('--continue_train', action='store_true',
                            help='load the latest model and continue training')
        parser.add_argument('--save_model', action='store_true',
                            help='save the model during training')
        # -1 is a sentinel value meaning "only save the last epoch".
        parser.add_argument('--save_epoch_freq', type=int, default=-1,
                            help='frequency of saving checkpoints at the end of epochs, -1 means only save the last epoch')

        # Logging and visualization
        parser.add_argument('--print_freq', type=int, default=1,
                            help='frequency of showing results on console')
        parser.add_argument('--save_latent_space', action='store_true',
                            help='save the latent space of input data to disk')

        # Mark this parameter set as a training configuration.
        self.isTrain = True
        self.isTest = False
        return parser