|
a |
|
b/src/config.py |
|
|
1 |
#-*- coding: utf-8 -*- |
|
|
2 |
import argparse |
|
|
3 |
|
|
|
4 |
parser = argparse.ArgumentParser() |
|
|
5 |
|
|
|
6 |
def add_argument_group(name): |
|
|
7 |
arg = parser.add_argument_group(name) |
|
|
8 |
return arg |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
misc_arg = add_argument_group('misc') |
|
|
12 |
misc_arg.add_argument('--split', type=bool, default = True) |
|
|
13 |
misc_arg.add_argument('--input_size', type=int, default = 256, |
|
|
14 |
help='multiplies of 256 by the structure of the model') |
|
|
15 |
misc_arg.add_argument('--use_network', type=bool, default = False) |
|
|
16 |
|
|
|
17 |
data_arg = add_argument_group('data') |
|
|
18 |
data_arg.add_argument('--downloading', type=bool, default = False) |
|
|
19 |
|
|
|
20 |
graph_arg = add_argument_group('graph') |
|
|
21 |
graph_arg.add_argument('--filter_length', type=int, default = 32) |
|
|
22 |
graph_arg.add_argument('--kernel_size', type=int, default = 16) |
|
|
23 |
graph_arg.add_argument('--drop_rate', type=float, default = 0.2) |
|
|
24 |
|
|
|
25 |
train_arg = add_argument_group('train') |
|
|
26 |
train_arg.add_argument('--feature', type=str, default = "MLII", |
|
|
27 |
help='one of MLII, V1, V2, V4, V5. Favorably MLII or V1') |
|
|
28 |
train_arg.add_argument('--epochs', type=int, default = 80) |
|
|
29 |
train_arg.add_argument('--batch', type=int, default = 256) |
|
|
30 |
train_arg.add_argument('--patience', type=int, default = 10) |
|
|
31 |
train_arg.add_argument('--min_lr', type=float, default = 0.00005) |
|
|
32 |
train_arg.add_argument('--checkpoint_path', type=str, default = None) |
|
|
33 |
train_arg.add_argument('--resume_epoch', type=int) |
|
|
34 |
train_arg.add_argument('--ensemble', type=bool, default = False) |
|
|
35 |
train_arg.add_argument('--trained_model', type=str, default = None, |
|
|
36 |
help='dir and filename of the trained model for usage.') |
|
|
37 |
|
|
|
38 |
predict_arg = add_argument_group('predict') |
|
|
39 |
predict_arg.add_argument('--num', type=int, default = None) |
|
|
40 |
predict_arg.add_argument('--upload', type=bool, default = False) |
|
|
41 |
predict_arg.add_argument('--sample_rate', type=int, default = None) |
|
|
42 |
predict_arg.add_argument('--cinc_download', type=bool, default = False) |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
|
|
|
46 |
def get_config(): |
|
|
47 |
config, unparsed = parser.parse_known_args() |
|
|
48 |
|
|
|
49 |
return config |