a b/ecgtoHR/utils.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as functional
4
5
import numpy as np
6
import pandas as pd
7
import scipy.signal
8
from sklearn.preprocessing import MinMaxScaler, StandardScaler
9
10
from tqdm import tqdm
11
import wfdb as wf
12
13
def dist_transform(window_size, ann):
    """ Compute the min-max scaled distance transform of a window from annotation positions.

    For every sample index ``i`` in ``range(window_size)`` the raw value is the
    distance to a neighbouring annotation:

    * before the second annotation / between annotations ``a`` and ``b``:
      distance to ``a`` if ``i < int((a + b) / 2)``, otherwise distance to ``b``;
    * after the last annotation: distance to the last annotation.

    The raw distances are then min-max scaled to [0, 1] (a constant signal
    scales to all zeros).

    Arguments:
        window_size{int} -- Window Length
        ann{array_like} -- Annotation sample positions (e.g. ground truth R-Peaks),
            assumed sorted in ascending order. Repeated positions are tolerated.
    Returns:
       ndarray -- transformed signal of shape (window_size, 1), or None if ``ann`` is empty
    """
    if len(ann) == 0:
        return None

    ann = np.asarray(ann)
    if window_size <= 0:
        return np.zeros((0, 1))

    positions = np.arange(window_size)
    last = len(ann) - 1

    # seg[i] = number of annotations (after the first) lying strictly before i,
    # i.e. the index of the annotation that opens the gap containing i.
    # searchsorted skips over repeated positions instead of stalling on them.
    seg = np.searchsorted(ann[1:], positions, side='left')
    prev_ann = ann[seg]
    next_ann = ann[np.minimum(seg + 1, last)]

    # Same split point as int((prev + next) / 2) in the original loop.
    middle = np.trunc((prev_ann + next_ann) / 2)
    # Past the last annotation (seg == last) only the previous one exists.
    use_next = (seg < last) & (positions >= middle)
    transform = np.where(use_next,
                         np.abs(positions - next_ann),
                         np.abs(positions - prev_ann)).astype(float)

    # Min-max scaling to [0, 1]; a zero range maps to zeros (as MinMaxScaler does).
    lo = transform.min()
    span = transform.max() - lo
    if span > 0:
        scaled = (transform - lo) / span
    else:
        scaled = np.zeros_like(transform)
    return scaled.reshape((-1, 1))
55
56
def getWindow(all_paths, window_size=3600):
    """ Window the ECG records and their beat annotations into fixed-length chunks.

    For each record, only the first signal channel is used. The record is cut
    into ``len(data) // window_size`` non-overlapping windows
    ``[start, start + window_size)``. Any trailing partial window is dropped.
    Annotation positions whose symbol is in the non-beat list below are
    discarded. The remaining positions falling inside a window are returned
    relative to that window's start.

    Arguments:
        all_paths{list} -- Paths to all the ECG records (read with wfdb, 'atr' annotations)
        window_size{int} -- Samples per window (default 3600)
    Returns:
        windowed_data{list(ndarray)},windowed_beats{list(ndarray)} -- Returns windowed ECG and windowed ground truth
    """
    # Annotation symbols that are not treated as beats; these positions are dropped.
    non_required_labels = frozenset(['[', '!', ']', 'x', '(', ')', 'p', 't', 'u', '`', "'", '^',
                                     '|', '~', '+', 's', 'T', '*', 'D', '=', '"', '@'])

    windowed_data = []
    windowed_beats = []

    for path in tqdm(all_paths):
        ann = wf.rdann(path, 'atr')
        record = wf.io.rdrecord(path)
        data = record.p_signal[:, 0]

        beats = np.asarray(ann.sample)
        is_beat = np.array([label not in non_required_labels for label in ann.symbol], dtype=bool)
        beats = beats[is_beat]

        for window in range(len(data) // window_size):
            start = window * window_size
            end = start + window_size
            # Half-open interval matching the data slice below, so boundary beats
            # land in the window whose samples contain them. A mask (rather than a
            # running search index) also keeps the final window's beats when no
            # beat exists beyond it.
            in_window = (beats >= start) & (beats < end)
            windowed_beats.append(beats[in_window] - start)
            windowed_data.append(data[start:end])

    return windowed_data, windowed_beats
104
105
def testDataEval(model, loader, criterion):
106
    
107
    """Test model on dataloader
108
    
109
    Arguments:
110
        model {torch object} -- Model   
111
        loader {torch object} -- Data Loader  
112
        criterion {torch object} -- Loss Function
113
    Returns:
114
        float -- total loss
115
    """
116
117
    model.eval()
118
    
119
    with torch.no_grad():
120
        
121
        total_loss = 0
122
        
123
        for (x,y) in loader:
124
            
125
            ecg,BR = x.unsqueeze(1).cuda(),y.unsqueeze(1).cuda()
126
            BR_pred = model(ecg)
127
            loss = criterion(BR_pred, BR)
128
            total_loss += loss
129
            
130
    return total_loss
131
132
133
def save_model(exp_dir, epoch, model, optimizer, best_dev_loss):
    """ Save a checkpoint of the model to ``<exp_dir>/best_model.pt``.

    The checkpoint is a dict with keys 'epoch', 'model' (model state_dict),
    'optimizer' (optimizer state_dict), 'best_dev_loss' and 'exp_dir'.
    An existing file at that path is overwritten.

    Arguments:
        exp_dir {str or os.PathLike} -- Directory the checkpoint is written to
        epoch {int} -- epoch at which model is checkpointed
        model -- model whose state_dict is checkpointed
        optimizer {torch optimizer object} -- optimizer whose state_dict is checkpointed
        best_dev_loss {float} -- loss of model to be checkpointed
    """
    import os  # only needed to build the checkpoint path portably

    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_dev_loss': best_dev_loss,
        'exp_dir': exp_dir,
    }
    # os.path.join accepts PathLike inputs and does not map '' to the filesystem root.
    torch.save(checkpoint, f=os.path.join(exp_dir, 'best_model.pt'))
155
156
def findValleys(signal, prominence=10, is_smooth=True, distance=10):
    """ Return the indices of prominent valleys (local minima) of a signal.

    The signal is negated and optionally passed through ``smooth`` (defined
    elsewhere in this module). Its peaks are then located with
    ``scipy.signal.find_peaks`` at a fixed prominence of 0.07.

    Arguments:
        signal {array_like} -- 1-D signal (lists are accepted)
        prominence -- NOTE(review): currently ignored; a fixed prominence of 0.07
            is used. Honoring this argument (default 10) would change results
            for existing callers.
        is_smooth {bool} -- Smooth the negated signal before peak detection (default True)
        distance -- NOTE(review): currently ignored.
    Returns:
        ndarray -- indices of the detected valleys
    """
    # Convert first: ``-1 * some_list`` would silently produce an empty list.
    inverted = -1 * np.asarray(signal)
    if is_smooth:
        inverted = smooth(inverted)
    valley_loc = scipy.signal.find_peaks(inverted, prominence=0.07)[0]
    return valley_loc