Diff of /Inference/run_model.py [000000] .. [c0487b]

Switch to unified view

a b/Inference/run_model.py
1
import argparse
2
import h5py 
3
import numpy as np
4
import os
5
from functools import reduce
6
import pandas as pd
7
import matplotlib.pyplot as plt
8
from matplotlib.widgets import Slider, Button, RadioButtons
9
import sys
10
from pyqtgraph.Qt import QtGui, QtCore
11
12
import torch
13
from torch.utils.data import TensorDataset,DataLoader
14
15
from PyQT_Plot import create_dashboard
16
from preprocess_data import data_read,windowing_and_resampling_hr,windowing_and_resampling_br
17
from utils import load_model_HR,load_model_BR,compute_heart_rate
18
19
def main(args):
20
    
21
    preprocessed_patient_data = data_read(args)
22
    print('-------- Data Acquisition Complete --------')
23
    windowed_patient_overlap,windowed_patient = windowing_and_resampling_hr(preprocessed_patient_data)
24
    print('-------- Pre-processing Complete for HR---------')
25
    windowed_patient_overlap_br = windowing_and_resampling_br(preprocessed_patient_data)
26
    print('-------- Pre-processing Complete for BR---------')
27
28
    ###
29
    patient_ecg = np.asarray(windowed_patient_overlap['ecg'][0][:60])
30
    actual_ecg_windows = np.asarray(windowed_patient['ecg'][0][:60])
31
    
32
    patient_ecg_br = np.asarray(windowed_patient_overlap_br['ecg'][0][:60])
33
    ###
34
35
    batch_len = 32
36
    batch_len_br = 1
37
    window_size = 5000
38
39
    patient_ecg = torch.from_numpy(patient_ecg).view(patient_ecg.shape[0],1,patient_ecg.shape[1]).float()
40
    input_ecg = TensorDataset(patient_ecg)
41
    testloader = DataLoader(input_ecg,batch_len)
42
43
    patient_ecg_br = torch.from_numpy(patient_ecg_br).view(patient_ecg_br.shape[0],1,patient_ecg_br.shape[1]).float()
44
    input_ecg_br = TensorDataset(patient_ecg_br)
45
    testloader_br = DataLoader(input_ecg_br,batch_len_br)
46
47
    SAVED_HR_MODEL_PATH = args.saved_hr_model_path
48
    SAVED_BR_MODEL_PATH = args.saved_br_model_path
49
    device = args.device
50
    
51
    ecg_peak_locs = load_model_HR(SAVED_HR_MODEL_PATH,testloader,device,batch_len,window_size)     
52
    br_peak_locs = load_model_BR(SAVED_BR_MODEL_PATH,testloader_br,device,batch_len,window_size)
53
54
    ### Finding Stored Paths
55
    save_dir = args.save_dir
56
    if not(os.path.isdir(save_dir)):
57
        os.mkdir(save_dir)
58
59
    save_path =  save_dir + '/r_peaks_patient_' + str(args.patient_no) + '.csv'
60
61
    all_hr = []
62
    initial_hr = len([peak for peak in list(ecg_peak_locs) if peak < 5000 * 6])
63
    
64
    for i in range(patient_ecg.shape[0]):
65
        all_hr.append( len([peak for peak in list(ecg_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ]))
66
    unique = np.unique(np.asarray(all_hr))
67
    peak_no = np.linspace(1,len(ecg_peak_locs),len(ecg_peak_locs)).astype(int)
68
    peak_no = peak_no.reshape(-1,1)
69
    ecg_peak_locs = ecg_peak_locs.reshape(-1,1) 
70
    ecg_peak_locs = np.hstack((peak_no,ecg_peak_locs))
71
72
    pd.DataFrame(ecg_peak_locs).to_csv(save_path , header=None, index=None)  
73
    print('-------- R Peaks Saved --------')
74
75
    all_br = []
76
    initial_br = len([peak for peak in list(br_peak_locs) if peak < 1250 * 6])
77
    for i in range(patient_ecg.shape[0]):
78
        all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 625 and peak < (i * 625 ) + 1250 * 6 ]))
79
        # all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ]))
80
81
    i = 1
82
    scatter_peak = []
83
    scatter_peak_1 = []
84
    ecg_point = []
85
    ecg_point_1 = []
86
    k = 0
87
    hr = []
88
    peak_locs = ecg_peak_locs[:,1]
89
    for j in range(len(peak_locs)):     
90
        if(peak_locs[j] < 5000*i):
91
            scatter_peak.append(peak_locs[j]-5000*(i-1))
92
            if(i< len(actual_ecg_windows)):
93
                ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]])
94
                k = k+1                         
95
        elif(peak_locs[j] >= 5000*i):
96
            scatter_peak_1.append(np.asarray(scatter_peak))
97
            hr.append(compute_heart_rate(scatter_peak_1[i-1]))
98
            ecg_point_1.append(np.asarray(ecg_point))                     
99
            scatter_peak = []
100
            ecg_point = []
101
            i = i+1
102
            scatter_peak.append(peak_locs[j]-5000*(i-1))
103
            k = 0
104
            if(i< len(actual_ecg_windows)):
105
                ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]])
106
                k = k+1
107
    import pdb;pdb.set_trace()
108
    if(args.viewer):
109
        create_dashboard(actual_ecg_windows,scatter_peak_1,all_hr,all_br)
110
    
111
    
112
        
113
114
if __name__ == "__main__":
115
116
    parser = argparse.ArgumentParser()
117
    parser.add_argument('--path_dir',help = 'Path to all the records')
118
    parser.add_argument('--saved_hr_model_path',help = 'Path to saved Heart rate model')
119
    parser.add_argument('--saved_br_model_path',help = 'Path to saved breathing rate model')
120
    parser.add_argument('--patient_no',default = 8,type = int,help = 'Patient used for testing')
121
    parser.add_argument('--device',default = 'cuda', help = 'cpu/cuda')
122
    parser.add_argument('--save_dir',default = 'saved_models/',help = 'Directory used for saving')
123
    parser.add_argument('--viewer',default = 0,type = int, help = 'To view ECG plot: 1, else: 0')
124
    
125
    args = parser.parse_args()
126
127
    main(args)
128