|
a |
|
b/train.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Created on Sun Apr 21 13:52:07 2019 |
|
|
4 |
|
|
|
5 |
@author: Administrator |
|
|
6 |
|
|
|
7 |
train.py: 训练模型 |
|
|
8 |
|
|
|
9 |
""" |
|
|
10 |
|
|
|
11 |
from Unet import Unet |
|
|
12 |
import LoadBatches1D |
|
|
13 |
import keras |
|
|
14 |
from keras import optimizers |
|
|
15 |
import warnings |
|
|
16 |
import matplotlib.pyplot as plt |
|
|
17 |
|
|
|
18 |
warnings.filterwarnings("ignore") |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def lr_schedule(epoch): |
|
|
22 |
# 训练网络时学习率衰减方案 |
|
|
23 |
lr = 0.0001 |
|
|
24 |
if epoch >= 50: |
|
|
25 |
lr = 0.00001 |
|
|
26 |
print('Learning rate: ', lr) |
|
|
27 |
return lr |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
train_sigs_path = 'G:/ECG_UNet/train_sigs/' |
|
|
31 |
train_segs_path = 'G:/ECG_UNet/train_labels/' |
|
|
32 |
train_batch_size = 1 |
|
|
33 |
n_classes = 3 |
|
|
34 |
input_length = 1800 |
|
|
35 |
optimizer_name = optimizers.Adam(lr_schedule(0)) |
|
|
36 |
val_sigs_path = 'G:/ECG_UNet/val_sigs/' |
|
|
37 |
val_segs_path = 'G:/ECG_UNet/val_labels/' |
|
|
38 |
val_batch_size = 2 |
|
|
39 |
|
|
|
40 |
lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule) |
|
|
41 |
|
|
|
42 |
model = Unet(n_classes, input_length=input_length) |
|
|
43 |
|
|
|
44 |
model.compile(loss='categorical_crossentropy', |
|
|
45 |
optimizer=optimizer_name, |
|
|
46 |
metrics=['accuracy']) |
|
|
47 |
|
|
|
48 |
model.summary() |
|
|
49 |
|
|
|
50 |
output_length = 1800 |
|
|
51 |
|
|
|
52 |
G = LoadBatches1D.SigSegmentationGenerator(train_sigs_path, train_segs_path, train_batch_size, n_classes, output_length) |
|
|
53 |
|
|
|
54 |
G2 = LoadBatches1D.SigSegmentationGenerator(val_sigs_path, val_segs_path, val_batch_size, n_classes, output_length) |
|
|
55 |
|
|
|
56 |
checkpointer = keras.callbacks.ModelCheckpoint(filepath='myNet.h5', monitor='val_acc', mode='max', save_best_only=True) |
|
|
57 |
|
|
|
58 |
history = model.fit_generator(G, 500//train_batch_size, validation_data=G2, validation_steps=200, epochs=70, |
|
|
59 |
callbacks=[checkpointer, lr_scheduler]) |
|
|
60 |
|
|
|
61 |
plt.figure() |
|
|
62 |
plt.plot(history.history['acc']) |
|
|
63 |
plt.plot(history.history['val_acc']) |
|
|
64 |
plt.title('Model accuracy') |
|
|
65 |
plt.ylabel('Accuracy') |
|
|
66 |
plt.xlabel('Epoch') |
|
|
67 |
plt.legend(['Train', 'Test'], loc='upper left') |
|
|
68 |
plt.grid(True) |
|
|
69 |
|
|
|
70 |
plt.figure() |
|
|
71 |
plt.plot(history.history['loss']) |
|
|
72 |
plt.plot(history.history['val_loss']) |
|
|
73 |
plt.title('Model loss') |
|
|
74 |
plt.ylabel('Loss') |
|
|
75 |
plt.xlabel('Epoch') |
|
|
76 |
plt.legend(['Train', 'Test'], loc='upper left') |
|
|
77 |
plt.grid(True) |