a b/train.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Sun Apr 21 13:52:07 2019
4
5
@author: Administrator
6
7
train.py: 训练模型
8
9
"""
10
11
from Unet import Unet
12
import LoadBatches1D
13
import keras
14
from keras import optimizers
15
import warnings
16
import matplotlib.pyplot as plt
17
18
# Drop every warning category raised from this point on.  The imports above
# have already executed, so warnings they emitted are not affected.
# NOTE(review): this also silences warnings Keras may raise during training
# (e.g. from callbacks whose monitored metric is missing) -- consider
# narrowing the filter to specific categories.
warnings.simplefilter("ignore")
19
20
21
def lr_schedule(epoch):
    """Step-decay learning-rate schedule used while training the network.

    Args:
        epoch: zero-based index of the current training epoch.

    Returns:
        1e-4 for epochs before 50, and 1e-5 from epoch 50 onward.
        The selected rate is also printed to stdout.
    """
    rate = 0.00001 if epoch >= 50 else 0.0001
    print('Learning rate: ', rate)
    return rate
28
29
30
# ---- Training data ----
train_sigs_path = 'G:/ECG_UNet/train_sigs/'
train_segs_path = 'G:/ECG_UNet/train_labels/'
train_batch_size = 1

# ---- Validation data ----
val_sigs_path = 'G:/ECG_UNet/val_sigs/'
val_segs_path = 'G:/ECG_UNet/val_labels/'
val_batch_size = 2

# ---- Model / optimisation ----
n_classes = 3
input_length = 1800
# Adam starts at the epoch-0 rate of lr_schedule (the call also prints it).
optimizer_name = optimizers.Adam(lr_schedule(0))
# Hands lr_schedule to Keras so the rate can change per epoch while training.
lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)
41
42
# Build the U-Net segmentation model and configure it for training.
model = Unet(n_classes, input_length=input_length)

# Accuracy is requested under the explicit name 'acc' so the training logs use
# the keys 'acc' / 'val_acc' -- the names ModelCheckpoint (monitor='val_acc')
# and the history plots later in this script look up.  With
# metrics=['accuracy'], newer Keras / tf.keras releases log 'accuracy' /
# 'val_accuracy' instead, so the checkpoint would never save and the plots
# would fail with a KeyError.
model.compile(loss='categorical_crossentropy',
              optimizer=optimizer_name,
              metrics=['acc'])

model.summary()
49
50
# Length argument passed to both generators (same value as input_length).
output_length = 1800

# Batch generators for the training set (G) and the validation set (G2).
G = LoadBatches1D.SigSegmentationGenerator(
    train_sigs_path, train_segs_path, train_batch_size, n_classes, output_length)

G2 = LoadBatches1D.SigSegmentationGenerator(
    val_sigs_path, val_segs_path, val_batch_size, n_classes, output_length)

# Write 'myNet.h5' only when the monitored 'val_acc' reaches a new maximum.
checkpointer = keras.callbacks.ModelCheckpoint(
    filepath='myNet.h5',
    monitor='val_acc',
    mode='max',
    save_best_only=True,
)
57
58
# Train for 70 epochs with the checkpoint and learning-rate callbacks.
# NOTE(review): 500 looks like the number of training signals (one pass per
# epoch) and 200 the number of validation batches of val_batch_size -- confirm.
history = model.fit_generator(
    G,
    steps_per_epoch=500 // train_batch_size,
    validation_data=G2,
    validation_steps=200,
    epochs=70,
    callbacks=[checkpointer, lr_scheduler],
)
60
61
def _history_key(logs, *candidates):
    """Return the first name in ``candidates`` that is a key of ``logs``.

    NOTE(review): depending on the Keras version and on how the accuracy
    metric was requested, it is logged as 'acc' or 'accuracy'; accepting
    either keeps the plots from failing after a long training run.

    Raises:
        KeyError: if none of ``candidates`` is present in ``logs``.
    """
    for name in candidates:
        if name in logs:
            return name
    raise KeyError('none of %r in training history (available: %r)'
                   % (candidates, sorted(logs)))


def _plot_train_vs_val(logs, key, title, ylabel):
    """Open a new figure plotting ``logs[key]`` against ``logs['val_' + key]``."""
    plt.figure()
    plt.plot(logs[key])
    plt.plot(logs['val_' + key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    # The second curve comes from validation_data (G2), not from a test set.
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.grid(True)


# One figure for accuracy and one for loss, each with a value per epoch.
_plot_train_vs_val(history.history,
                   _history_key(history.history, 'acc', 'accuracy'),
                   'Model accuracy', 'Accuracy')
_plot_train_vs_val(history.history, 'loss', 'Model loss', 'Loss')

# Nothing else in this script displays or saves the figures, so show them
# explicitly; otherwise a non-interactive run exits without ever drawing them.
plt.show()