|
a |
|
b/deepheart/train_model.py |
|
|
1 |
from parser import PCG |
|
|
2 |
from model import CNN |
|
|
3 |
import sys |
|
|
4 |
|
|
|
5 |
true_strs = {"True", "true", "t"} |
|
|
6 |
|
|
|
7 |
def load_and_train_model(model_path, load_pretrained): |
|
|
8 |
pcg = PCG(model_path) |
|
|
9 |
|
|
|
10 |
if load_pretrained: |
|
|
11 |
pcg.load("/tmp") |
|
|
12 |
else: |
|
|
13 |
pcg.initialize_wav_data() |
|
|
14 |
|
|
|
15 |
cnn = CNN(pcg, epochs=100, dropout=0.5) |
|
|
16 |
cnn.train() |
|
|
17 |
|
|
|
18 |
if __name__ == '__main__': |
|
|
19 |
data_path = sys.argv[1] |
|
|
20 |
|
|
|
21 |
load_pretrained = False |
|
|
22 |
if len(sys.argv) == 3: |
|
|
23 |
load_pretrained = sys.argv[2] in true_strs |
|
|
24 |
|
|
|
25 |
load_and_train_model(data_path, load_pretrained) |