|
a |
|
b/train.py |
|
|
1 |
from model.initialization import initialization |
|
|
2 |
from config import conf |
|
|
3 |
import argparse |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def boolean_string(s): |
|
|
7 |
if s.upper() not in {'FALSE', 'TRUE'}: |
|
|
8 |
raise ValueError('Not a valid boolean string') |
|
|
9 |
return s.upper() == 'TRUE' |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
parser = argparse.ArgumentParser(description='Train') |
|
|
13 |
parser.add_argument('--cache', default=True, type=boolean_string, |
|
|
14 |
help='cache: if set as TRUE all the training data will be loaded at once' |
|
|
15 |
' before the training start. Default: TRUE') |
|
|
16 |
opt = parser.parse_args() |
|
|
17 |
|
|
|
18 |
m = initialization(conf, train=opt.cache)[0] |
|
|
19 |
|
|
|
20 |
print("Training START") |
|
|
21 |
m.fit() |
|
|
22 |
print("Training COMPLETE") |