Diff of /deepheart/train_model.py [000000] .. [d3af21]

Switch to unified view

a b/deepheart/train_model.py
1
from parser import PCG
2
from model import CNN
3
import sys
4
5
true_strs = {"True", "true", "t"}
6
7
def load_and_train_model(model_path, load_pretrained):
8
    pcg = PCG(model_path)
9
10
    if load_pretrained:
11
        pcg.load("/tmp")
12
    else:
13
        pcg.initialize_wav_data()
14
15
    cnn = CNN(pcg, epochs=100, dropout=0.5)
16
    cnn.train()
17
18
if __name__ == '__main__':
19
    data_path = sys.argv[1]
20
21
    load_pretrained = False
22
    if len(sys.argv) == 3:
23
        load_pretrained = sys.argv[2] in true_strs
24
25
    load_and_train_model(data_path, load_pretrained)