|
a |
|
b/model/lstmRNNModel.py |
|
|
1 |
""" |
|
|
2 |
LSTM-RNN model for OSA detection. |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
from keras.models import Sequential |
|
|
6 |
from keras.layers import Dense, LSTM |
|
|
7 |
from keras.utils import plot_model |
|
|
8 |
import os |
|
|
9 |
import numpy as np |
|
|
10 |
import tensorflow as tf |
|
|
11 |
import keras.backend.tensorflow_backend as KTF |
|
|
12 |
|
|
|
13 |
from model.common import TrainingMonitor, ModelCheckpoint, LossHistory |
|
|
14 |
|
|
|
15 |
RR_INTERVALS_INTERPOLATION = 240 |
|
|
16 |
# handcraft_features |
|
|
17 |
|
|
|
18 |
test_number = 1 |
|
|
19 |
base_floder_path = "result/lstm/" + "test_" + str(test_number) + "/" |
|
|
20 |
|
|
|
21 |
if not os.path.exists(base_floder_path): |
|
|
22 |
os.makedirs(base_floder_path) |
|
|
23 |
train_loss_path = base_floder_path + "train_loss.txt" |
|
|
24 |
validation_loss_path = base_floder_path + "validation_loss.txt" |
|
|
25 |
train_acc_path = base_floder_path + "train_acc.txt" |
|
|
26 |
validation_acc_path = base_floder_path + "validation_acc.txt" |
|
|
27 |
|
|
|
28 |
# GPU config |
|
|
29 |
config = tf.ConfigProto() |
|
|
30 |
config.gpu_options.allow_growth = True |
|
|
31 |
config.gpu_options.per_process_gpu_memory_fraction = 0.7 |
|
|
32 |
sess = tf.Session(config=config) |
|
|
33 |
KTF.set_session(sess) |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def get_dataset(): |
|
|
37 |
train_rri_amp_edr = np.load("G:/python project/apneaECGCode/data/apnea-ecg_train_clear_rri_ramp_edr.npy") |
|
|
38 |
train_label = np.load("G:/python project/apneaECGCode/data/apnea-ecg_train_clear_label.npy") |
|
|
39 |
test_rri_amp_edr = np.load("G:/python project/apneaECGCode/data/apnea-ecg_test_clear_rri_ramp_edr.npy") |
|
|
40 |
test_label = np.load("G:/python project/apneaECGCode/data/apnea-ecg_test_clear_label.npy") |
|
|
41 |
|
|
|
42 |
train_label = train_label.astype(dtype=np.int) |
|
|
43 |
test_label = test_label.astype(dtype=np.int) |
|
|
44 |
return train_rri_amp_edr, train_label, test_rri_amp_edr, test_label |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def create_lstm_model(input_shape): |
|
|
48 |
model = Sequential() |
|
|
49 |
model.add(LSTM(384, input_shape=input_shape, use_bias=True, dropout=0.1, |
|
|
50 |
recurrent_dropout=0.05, return_sequences=True)) |
|
|
51 |
# model.add(LeakyReLU(alpha=1)) |
|
|
52 |
# model.add(BatchNormalization()) |
|
|
53 |
model.add(LSTM(384, use_bias=True, dropout=0.2, |
|
|
54 |
recurrent_dropout=0.05, return_sequences=True)) |
|
|
55 |
# model.add(LeakyReLU(alpha=1)) |
|
|
56 |
# model.add(BatchNormalization()) |
|
|
57 |
model.add(LSTM(384, use_bias=True, dropout=0.3, |
|
|
58 |
recurrent_dropout=0.05)) |
|
|
59 |
# model.add(LeakyReLU(alpha=1)) |
|
|
60 |
# model.add(BatchNormalization()) |
|
|
61 |
# model.add(LSTM(64, use_bias=True, |
|
|
62 |
# dropout=0.7, recurrent_dropout=0.7)) |
|
|
63 |
# model.add(LeakyReLU(alpha=1)) |
|
|
64 |
# model.add(BatchNormalization()) |
|
|
65 |
model.add(Dense(128)) |
|
|
66 |
# model.add(Dropout(0.8)) |
|
|
67 |
# model.add(LeakyReLU(alpha=1)) |
|
|
68 |
model.add(Dense(64)) |
|
|
69 |
model.add(Dense(32)) |
|
|
70 |
# model.add(Dropout(0.5)) |
|
|
71 |
# model.add(LeakyReLU(alpha=1)) |
|
|
72 |
model.add(Dense(1, activation="sigmoid")) |
|
|
73 |
|
|
|
74 |
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy']) |
|
|
75 |
|
|
|
76 |
model.summary() |
|
|
77 |
plot_model(model, to_file=base_floder_path + '/lstm_model.png', show_shapes=True) |
|
|
78 |
|
|
|
79 |
return model |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
def train_network(): |
|
|
83 |
print("read data...") |
|
|
84 |
X_train1, y_train, X_test1, y_test = get_dataset() |
|
|
85 |
|
|
|
86 |
model = create_lstm_model(input_shape=(RR_INTERVALS_INTERPOLATION, 3)) |
|
|
87 |
fig_path = base_floder_path |
|
|
88 |
model_file_path = base_floder_path + "/model" |
|
|
89 |
if not os.path.exists(model_file_path): |
|
|
90 |
os.makedirs(model_file_path) |
|
|
91 |
model_file_path += "/model_{epoch:02d}-{val_acc:.6f}.hdf5" |
|
|
92 |
checkpoint = ModelCheckpoint(model_file_path, monitor='val_acc', verbose=1, save_best_only=True) |
|
|
93 |
callbacks = [ |
|
|
94 |
TrainingMonitor(fig_path, model, train_loss_path, validation_loss_path, train_acc_path, validation_acc_path) |
|
|
95 |
, checkpoint] |
|
|
96 |
print("Training") |
|
|
97 |
history = LossHistory() |
|
|
98 |
history.init() |
|
|
99 |
model.fit(X_train1, y_train, batch_size=128, epochs=500, callbacks=callbacks, validation_data=(X_test1, y_test)) |
|
|
100 |
return model |
|
|
101 |
|
|
|
102 |
|
|
|
103 |
if __name__ == '__main__': |
|
|
104 |
train_network() |
|
|
105 |
|