Diff of /test_a_sig.py [000000] .. [eaa663]

Switch to unified view

a b/test_a_sig.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Sun Apr 21 14:08:55 2019
4
5
@author: Winham
6
7
test_a_sig.py: 加载训练好的模型,从验证集中随机选取一条信号进行测试
8
9
"""
10
11
import os
12
import numpy as np
13
import tensorflow as tf
14
from keras.models import load_model
15
import keras.backend as K
16
from sklearn import preprocessing as prep
17
import matplotlib.pyplot as plt
18
import time
19
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
20
21
val_sig_path = 'G:/ECG_UNet/val_sigs/'
22
val_label_path = 'G:/ECG_UNet/val_labels/'
23
24
sig_files = os.listdir(val_sig_path)
25
label_files = os.listdir(val_label_path)
26
27
select = np.random.choice(sig_files, 1)[0]
28
29
a_sig = np.load(val_sig_path+select)
30
a_seg = np.load(val_label_path+select)
31
32
K.clear_session()
33
tf.reset_default_graph()
34
model = load_model('myNet.h5')
35
36
a_sig = np.expand_dims(prep.scale(a_sig), axis=1)
37
a_sig = np.expand_dims(a_sig, axis=0)
38
39
tic = time.time()
40
a_pred = model.predict(a_sig)
41
toc = time.time()
42
43
44
print('Elapsed time: '+str(toc-tic)+' seconds.')
45
plt.plot(a_sig[0, :, 0])
46
plt.grid(True)
47
48
plt.plot(a_pred[0, :, 0], 'b')
49
plt.plot(a_pred[0, :, 1], 'k')
50
plt.plot(a_pred[0, :, 2], 'r')
51
plt.legend(['Sig', 'Background', 'Normal', 'PVC'], loc='lower right')