|
a |
|
b/params/train_test_params.py |
|
|
1 |
from .basic_params import BasicParams |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
class TrainTestParams(BasicParams): |
|
|
5 |
""" |
|
|
6 |
This class is a son class of BasicParams. |
|
|
7 |
This class includes parameters for training & testing and parameters inherited from the father class. |
|
|
8 |
""" |
|
|
9 |
def initialize(self, parser): |
|
|
10 |
parser = BasicParams.initialize(self, parser) |
|
|
11 |
|
|
|
12 |
# Training parameters |
|
|
13 |
parser.add_argument('--epoch_num_p1', type=int, default=50, |
|
|
14 |
help='epoch number for phase 1') |
|
|
15 |
parser.add_argument('--epoch_num_p2', type=int, default=50, |
|
|
16 |
help='epoch number for phase 2') |
|
|
17 |
parser.add_argument('--epoch_num_p3', type=int, default=100, |
|
|
18 |
help='epoch number for phase 3') |
|
|
19 |
parser.add_argument('--lr', type=float, default=1e-4, |
|
|
20 |
help='initial learning rate') |
|
|
21 |
parser.add_argument('--beta1', type=float, default=0.5, |
|
|
22 |
help='momentum term of adam') |
|
|
23 |
parser.add_argument('--lr_policy', type=str, default='linear', |
|
|
24 |
help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]') |
|
|
25 |
parser.add_argument('--epoch_count', type=int, default=1, |
|
|
26 |
help='the starting epoch count, default start from 1') |
|
|
27 |
parser.add_argument('--epoch_num_decay', type=int, default=50, |
|
|
28 |
help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)') |
|
|
29 |
parser.add_argument('--decay_step_size', type=int, default=50, |
|
|
30 |
help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)') |
|
|
31 |
parser.add_argument('--weight_decay', type=float, default=1e-4, |
|
|
32 |
help='weight decay (L2 penalty)') |
|
|
33 |
|
|
|
34 |
# Network saving and loading parameters |
|
|
35 |
parser.add_argument('--continue_train', action='store_true', |
|
|
36 |
help='load the latest model and continue training') |
|
|
37 |
parser.add_argument('--save_model', action='store_true', |
|
|
38 |
help='save the model during training') |
|
|
39 |
parser.add_argument('--save_epoch_freq', type=int, default=-1, |
|
|
40 |
help='frequency of saving checkpoints at the end of epochs, -1 means only save the last epoch') |
|
|
41 |
|
|
|
42 |
# Logging and visualization |
|
|
43 |
parser.add_argument('--print_freq', type=int, default=1, |
|
|
44 |
help='frequency of showing results on console') |
|
|
45 |
|
|
|
46 |
# Dataset parameters |
|
|
47 |
parser.add_argument('--train_ratio', type=float, default=0.8, |
|
|
48 |
help='ratio of training set in the full dataset') |
|
|
49 |
parser.add_argument('--test_ratio', type=float, default=0.2, |
|
|
50 |
help='ratio of testing set in the full dataset') |
|
|
51 |
|
|
|
52 |
self.isTrain = True |
|
|
53 |
self.isTest = True |
|
|
54 |
return parser |