Diff of /train.py [000000] .. [eaa663]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Apr 21 13:52:07 2019
+
+@author: Administrator
+
+train.py: 训练模型
+
+"""
+
+from Unet import Unet
+import LoadBatches1D
+import keras
+from keras import optimizers
+import warnings
+import matplotlib.pyplot as plt
+
+warnings.filterwarnings("ignore")
+
+
+def lr_schedule(epoch):
+    # 训练网络时学习率衰减方案
+    lr = 0.0001
+    if epoch >= 50:
+        lr = 0.00001
+    print('Learning rate: ', lr)
+    return lr
+
+
+train_sigs_path = 'G:/ECG_UNet/train_sigs/'
+train_segs_path = 'G:/ECG_UNet/train_labels/'
+train_batch_size = 1
+n_classes = 3
+input_length = 1800
+optimizer_name = optimizers.Adam(lr_schedule(0))
+val_sigs_path = 'G:/ECG_UNet/val_sigs/'
+val_segs_path = 'G:/ECG_UNet/val_labels/'
+val_batch_size = 2
+
+lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)
+
+model = Unet(n_classes, input_length=input_length)
+
+model.compile(loss='categorical_crossentropy',
+              optimizer=optimizer_name,
+              metrics=['accuracy'])
+
+model.summary()
+
+output_length = 1800
+
+G = LoadBatches1D.SigSegmentationGenerator(train_sigs_path, train_segs_path, train_batch_size, n_classes, output_length)
+
+G2 = LoadBatches1D.SigSegmentationGenerator(val_sigs_path, val_segs_path, val_batch_size, n_classes, output_length)
+
+checkpointer = keras.callbacks.ModelCheckpoint(filepath='myNet.h5', monitor='val_acc', mode='max', save_best_only=True)
+
+history = model.fit_generator(G, 500//train_batch_size, validation_data=G2, validation_steps=200, epochs=70,
+                        callbacks=[checkpointer, lr_scheduler])
+
+plt.figure()
+plt.plot(history.history['acc'])
+plt.plot(history.history['val_acc'])
+plt.title('Model accuracy')
+plt.ylabel('Accuracy')
+plt.xlabel('Epoch')
+plt.legend(['Train', 'Test'], loc='upper left')
+plt.grid(True)
+
+plt.figure()
+plt.plot(history.history['loss'])
+plt.plot(history.history['val_loss'])
+plt.title('Model loss')
+plt.ylabel('Loss')
+plt.xlabel('Epoch')
+plt.legend(['Train', 'Test'], loc='upper left')
+plt.grid(True)