a b/src/train.py
1
from __future__ import division, print_function
2
import matplotlib.pyplot as plt
3
import numpy as np
4
from tqdm import tqdm
5
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, LearningRateScheduler
6
from keras import models
7
from graph import ECG_model
8
from config import get_config
9
from utils import *
10
11
def train(config, X, y, Xval=None, yval=None):
    """Train an ECG model on ``X``/``y`` and run ``print_results`` on the validation data.

    Args:
        config: options object (from ``get_config()``). This function reads
            ``split``, ``checkpoint_path``, ``resume_epoch``, ``patience``,
            ``feature``, ``epochs`` and ``batch``; ``ECG_model`` and
            ``print_results`` may read further fields.
        X: training inputs; a trailing channel axis is added with
            ``np.expand_dims(X, axis=2)`` (presumably X is (n_samples, length)
            — TODO confirm against ``loaddata``).
        y: training targets, row-aligned with ``X``.
        Xval, yval: validation data. Required when ``config.split`` is truthy;
            when it is falsy they are ignored and a seeded 80/20
            ``train_test_split`` of ``X``/``y`` is used instead.

    Raises:
        ValueError: if ``config.split`` is set but ``Xval``/``yval`` is missing,
            or if the training or validation data contain NaN.
    """
    # Class names handed to print_results (presumably in label-index order —
    # confirm against utils). Other MIT-BIH symbols are currently disabled:
    # 'L','R','f','j','E','a','J','Q','e','S'
    classes = ['N', 'V', '/', 'A', 'F', '~']
    print("Initial shapes - X:", X.shape, "y:", y.shape)
    print("Initial validation shapes - Xval:", Xval.shape if Xval is not None else None,
          "yval:", yval.shape if yval is not None else None)
    print("Any NaN in initial X:", np.any(np.isnan(X)), "y:", np.any(np.isnan(y)))

    Xe = np.expand_dims(X, axis=2)
    if not config.split:
        from sklearn.model_selection import train_test_split
        Xe, Xvale, y, yval = train_test_split(Xe, y, test_size=0.2, random_state=1)
    else:
        # Without this guard np.expand_dims(None, ...) fails with an unhelpful AxisError.
        if Xval is None or yval is None:
            raise ValueError("config.split is set but Xval/yval were not provided")
        Xvale = np.expand_dims(Xval, axis=2)

    print("Data shapes before training - Xe:", Xe.shape, "y:", y.shape)
    print("Val shapes before training - Xvale:", Xvale.shape, "yval:", yval.shape)

    # Validate before loading/building a model so bad data fails fast.
    if np.any(np.isnan(Xe)) or np.any(np.isnan(y)):
        raise ValueError("Input data contains None/NaN values")
    if np.any(np.isnan(Xvale)) or np.any(np.isnan(yval)):
        raise ValueError("Validation data contains None/NaN values")

    if config.checkpoint_path is not None:
        model = models.load_model(config.checkpoint_path)
        initial_epoch = config.resume_epoch  # epoch number to resume from
    else:
        model = ECG_model(config)
        initial_epoch = 0

    mkdir_recursive('models')

    # An integer ModelCheckpoint.save_freq counts *batches*, not epochs, so the
    # previous save_freq=10 rewrote the checkpoint every 10 batches. Convert to
    # "every 10 epochs" using the number of batches fit() runs per epoch
    # (Keras uses a batch size of 32 when batch_size is None).
    batch_size = config.batch or 32
    steps_per_epoch = max(int(np.ceil(len(Xe) / batch_size)), 1)
    checkpoint_every = 10 * steps_per_epoch

    callbacks = [
        EarlyStopping(patience=config.patience, verbose=1),
        # NOTE(review): ReduceLROnPlateau only lowers the LR while it is above
        # min_lr, so with min_lr=0.01 it never fires if the optimizer starts
        # below 0.01 — confirm against the learning rate ECG_model uses.
        ReduceLROnPlateau(factor=0.5, patience=3, min_lr=0.01, verbose=1),
        TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True),
        ModelCheckpoint('models/{}-latest.keras'.format(config.feature), monitor='val_loss',
                        save_best_only=False, verbose=1, save_freq=checkpoint_every),
    ]

    model.fit(Xe, y,
              validation_data=(Xvale, yval),
              epochs=config.epochs,
              batch_size=config.batch,
              callbacks=callbacks,
              initial_epoch=initial_epoch)
    print_results(config, model, Xvale, yval, classes)
63
64
def main(config):
    """Load the dataset selected by ``config`` and hand it to ``train``.

    Echoes the chosen feature and the loaded training arrays to stdout first.
    """
    print('feature:', config.feature)
    dataset = loaddata(config.input_size, config.feature)
    X, y, Xval, yval = dataset
    print(X, y)
    train(config, X, y, Xval, yval)
70
71
if __name__ == "__main__":
    # Script entry point: parse options and run the training pipeline.
    main(config=get_config())