|
a |
|
b/ecgtoHR/utils.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as functional |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import pandas as pd |
|
|
7 |
import scipy.signal |
|
|
8 |
from sklearn.preprocessing import MinMaxScaler, StandardScaler |
|
|
9 |
|
|
|
10 |
from tqdm import tqdm |
|
|
11 |
import wfdb as wf |
|
|
12 |
|
|
|
13 |
def dist_transform(window_size, ann): |
|
|
14 |
|
|
|
15 |
""" Compute distance transform of Respiration signaal based on breath positions |
|
|
16 |
Arguments: |
|
|
17 |
window_size{int} -- Window Length |
|
|
18 |
ann{ndarray} -- The ground truth R-Peaks |
|
|
19 |
Returns: |
|
|
20 |
ndarray -- transformed signal |
|
|
21 |
""" |
|
|
22 |
|
|
|
23 |
length = window_size |
|
|
24 |
transform = [] |
|
|
25 |
|
|
|
26 |
sample = 0 |
|
|
27 |
if len(ann) == 0: |
|
|
28 |
return None |
|
|
29 |
|
|
|
30 |
if len(ann) ==1: |
|
|
31 |
for i in range(window_size): |
|
|
32 |
transform.append(abs(i-ann[sample])) |
|
|
33 |
else: |
|
|
34 |
for i in range(window_size): |
|
|
35 |
|
|
|
36 |
if sample+1 == len(ann): |
|
|
37 |
for j in range(i,window_size): |
|
|
38 |
|
|
|
39 |
transform.append(abs(j - nextAnn)) |
|
|
40 |
break |
|
|
41 |
prevAnn = ann[sample] |
|
|
42 |
nextAnn = ann[sample+1] |
|
|
43 |
middle = int((prevAnn + nextAnn )/2) |
|
|
44 |
if i < middle: |
|
|
45 |
transform.append(abs(i - prevAnn)) |
|
|
46 |
elif i>= middle: |
|
|
47 |
transform.append(abs(i- nextAnn)) |
|
|
48 |
if i == nextAnn: |
|
|
49 |
sample+=1 |
|
|
50 |
|
|
|
51 |
transform = np.array(transform) |
|
|
52 |
minmaxScaler = MinMaxScaler() |
|
|
53 |
transform = minmaxScaler.fit_transform(transform.reshape((-1,1))) |
|
|
54 |
return transform |
|
|
55 |
|
|
|
56 |
def getWindow(all_paths): |
|
|
57 |
|
|
|
58 |
""" Windowing the ECG and its corresponding Distance Transform |
|
|
59 |
Arguments: |
|
|
60 |
all_paths{list} -- Paths to all the ECG files |
|
|
61 |
Returns: |
|
|
62 |
windowed_data{list(ndarray)},windowed_beats{list(ndarray)} -- Returns winodwed ECG and windowed ground truth |
|
|
63 |
""" |
|
|
64 |
|
|
|
65 |
windowed_data = [] |
|
|
66 |
windowed_beats = [] |
|
|
67 |
count = 0 |
|
|
68 |
count1 = 0 |
|
|
69 |
|
|
|
70 |
for path in tqdm(all_paths): |
|
|
71 |
|
|
|
72 |
ann = wf.rdann(path,'atr') |
|
|
73 |
record = wf.io.rdrecord(path) |
|
|
74 |
beats = ann.sample |
|
|
75 |
labels = ann.symbol |
|
|
76 |
len_beats = len(beats) |
|
|
77 |
data = record.p_signal[:,0] |
|
|
78 |
|
|
|
79 |
ini_index = 0 |
|
|
80 |
final_index = 0 |
|
|
81 |
### Checking for Beat annotations |
|
|
82 |
non_required_labels = ['[','!',']','x','(',')','p','t','u','`',"'",'^','|','~','+','s','T','*','D','=','"','@'] |
|
|
83 |
for window in range(len(data) // 3600): |
|
|
84 |
count += 1 |
|
|
85 |
for r_peak in range(ini_index,len_beats): |
|
|
86 |
if beats[r_peak] > (window+1) * 3600: |
|
|
87 |
final_index = r_peak |
|
|
88 |
#print('FInal index:',final_index) |
|
|
89 |
break |
|
|
90 |
record_anns = list(beats[ini_index: final_index]) |
|
|
91 |
record_labs = labels[ini_index: final_index] |
|
|
92 |
to_del_index = [] |
|
|
93 |
for actual_lab in range(len(record_labs)): |
|
|
94 |
for lab in range(len(non_required_labels)): |
|
|
95 |
if(record_labs[actual_lab] == non_required_labels[lab]): |
|
|
96 |
to_del_index.append(actual_lab) |
|
|
97 |
for indice in range(len(to_del_index)-1,-1,-1): |
|
|
98 |
del record_anns[to_del_index[indice]] |
|
|
99 |
windowed_beats.append(np.asarray(record_anns) - (window) * 3600) |
|
|
100 |
windowed_data.append(data[window * 3600 : (window+1) * 3600]) |
|
|
101 |
ini_index = final_index |
|
|
102 |
|
|
|
103 |
return windowed_data,windowed_beats |
|
|
104 |
|
|
|
105 |
def testDataEval(model, loader, criterion): |
|
|
106 |
|
|
|
107 |
"""Test model on dataloader |
|
|
108 |
|
|
|
109 |
Arguments: |
|
|
110 |
model {torch object} -- Model |
|
|
111 |
loader {torch object} -- Data Loader |
|
|
112 |
criterion {torch object} -- Loss Function |
|
|
113 |
Returns: |
|
|
114 |
float -- total loss |
|
|
115 |
""" |
|
|
116 |
|
|
|
117 |
model.eval() |
|
|
118 |
|
|
|
119 |
with torch.no_grad(): |
|
|
120 |
|
|
|
121 |
total_loss = 0 |
|
|
122 |
|
|
|
123 |
for (x,y) in loader: |
|
|
124 |
|
|
|
125 |
ecg,BR = x.unsqueeze(1).cuda(),y.unsqueeze(1).cuda() |
|
|
126 |
BR_pred = model(ecg) |
|
|
127 |
loss = criterion(BR_pred, BR) |
|
|
128 |
total_loss += loss |
|
|
129 |
|
|
|
130 |
return total_loss |
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def save_model(exp_dir, epoch, model, optimizer,best_dev_loss): |
|
|
134 |
|
|
|
135 |
""" save checkpoint of model |
|
|
136 |
|
|
|
137 |
Arguments: |
|
|
138 |
exp_dir {string} -- Path to checkpoint |
|
|
139 |
epoch {int} -- epoch at which model is checkpointed |
|
|
140 |
model -- model state to be checkpointed |
|
|
141 |
optimizer {torch optimizer object} -- optimizer state of model to be checkpoint |
|
|
142 |
best_dev_loss {float} -- loss of model to be checkpointed |
|
|
143 |
""" |
|
|
144 |
|
|
|
145 |
out = torch.save( |
|
|
146 |
{ |
|
|
147 |
'epoch': epoch, |
|
|
148 |
'model': model.state_dict(), |
|
|
149 |
'optimizer': optimizer.state_dict(), |
|
|
150 |
'best_dev_loss': best_dev_loss, |
|
|
151 |
'exp_dir':exp_dir |
|
|
152 |
}, |
|
|
153 |
f=exp_dir + '/best_model.pt' |
|
|
154 |
) |
|
|
155 |
|
|
|
156 |
def findValleys(signal, prominence = 10, is_smooth = True , distance = 10): |
|
|
157 |
|
|
|
158 |
""" Return prominent peaks and valleys based on scipy's find_peaks function """ |
|
|
159 |
smoothened = smooth(-1*signal) |
|
|
160 |
valley_loc = scipy.signal.find_peaks(smoothened, prominence= 0.07)[0] |
|
|
161 |
|
|
|
162 |
return valley_loc |