Diff of /src/config.py [000000] .. [a378de]

Switch to unified view

a b/src/config.py
1
#-*- coding: utf-8 -*-
2
import argparse
3
4
parser = argparse.ArgumentParser()
5
6
def add_argument_group(name):
7
    arg = parser.add_argument_group(name)
8
    return arg
9
10
11
misc_arg = add_argument_group('misc')
12
misc_arg.add_argument('--split', type=bool, default = True)
13
misc_arg.add_argument('--input_size', type=int, default = 256, 
14
                      help='multiplies of 256 by the structure of the model') 
15
misc_arg.add_argument('--use_network', type=bool, default = False)
16
17
data_arg = add_argument_group('data')
18
data_arg.add_argument('--downloading', type=bool, default = False)
19
20
graph_arg = add_argument_group('graph')
21
graph_arg.add_argument('--filter_length', type=int, default = 32)
22
graph_arg.add_argument('--kernel_size', type=int, default = 16)
23
graph_arg.add_argument('--drop_rate', type=float, default = 0.2)
24
25
train_arg = add_argument_group('train')
26
train_arg.add_argument('--feature', type=str, default = "MLII",
27
                       help='one of MLII, V1, V2, V4, V5. Favorably MLII or V1')
28
train_arg.add_argument('--epochs', type=int, default = 80)
29
train_arg.add_argument('--batch', type=int, default = 256)
30
train_arg.add_argument('--patience', type=int, default = 10)
31
train_arg.add_argument('--min_lr', type=float, default = 0.00005)
32
train_arg.add_argument('--checkpoint_path', type=str, default = None)
33
train_arg.add_argument('--resume_epoch', type=int)
34
train_arg.add_argument('--ensemble', type=bool, default = False)
35
train_arg.add_argument('--trained_model', type=str, default = None, 
36
                       help='dir and filename of the trained model for usage.')
37
38
predict_arg = add_argument_group('predict')
39
predict_arg.add_argument('--num', type=int, default = None)
40
predict_arg.add_argument('--upload', type=bool, default = False)
41
predict_arg.add_argument('--sample_rate', type=int, default = None)
42
predict_arg.add_argument('--cinc_download', type=bool, default = False)
43
44
45
46
def get_config():
47
    config, unparsed = parser.parse_known_args()
48
49
    return config