Diff of /ecgtoBR/utils.py [000000] .. [c0487b]

Switch to unified view

a b/ecgtoBR/utils.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as functional
4
5
import numpy as np
6
import pandas as pd
7
import scipy.signal
8
from sklearn.preprocessing import MinMaxScaler, StandardScaler
9
10
def testDataEval(model, loader, criterion):
11
12
    """Test model on dataloader
13
    
14
    Arguments:
15
        model {torch object} -- Model   
16
        loader {torch object} -- Data Loader  
17
        criterion {torch object} -- Loss Function
18
    Returns:
19
        float -- total loss
20
    """
21
    
22
    model.eval()
23
    
24
    with torch.no_grad():
25
        
26
        total_loss = 0
27
        
28
        for (x,y) in loader:
29
            
30
            ecg,BR = x.cuda(),y.cuda()
31
            BR_pred = model(ecg)
32
            loss = criterion(BR_pred, BR)
33
            total_loss += loss
34
            
35
    return total_loss
36
37
def smooth(signal,window_len=50):
38
    """Compute moving average of specified window length
39
    
40
    Arguments:
41
        signal {ndarray} -- signal to smooth
42
    
43
    Keyword Arguments:
44
        window_len {int} -- size of window over which average is to be computed (default: {50})
45
    
46
    Returns:
47
        ndarray   -- smoothed signal
48
    """
49
    
50
    y = pd.DataFrame(signal).rolling(window_len,center = True, min_periods = 1).mean().values.reshape((-1,))
51
    return y
52
53
def findValleys(signal, prominence = 0.07):
54
    """Find valleys of distance transform to estimate breath positions   
55
    Arguments:
56
        signal {ndarray} -- transform to get breath positions
57
    
58
    Keyword Arguments:
59
        prominence {int} -- threshold prominence to detect peaks (default: {0.07})
60
    
61
    Returns:
62
        ndarray -- valley locations in signal
63
    """
64
    smoothened = smooth(-1*signal)
65
    valley_loc = scipy.signal.find_peaks(smoothened, prominence= prominence)[0]
66
    
67
    return valley_loc
68
69
def getBR(signal, model):
70
    """ Get Breathing Rate after passing ECG through Model
71
    
72
    Arguments:
73
        signal {torch tensor} -- input ECG signal
74
        model  -- ECG to BR model
75
    
76
    Returns:
77
        ndarray -- position of predicted valley and corresponding predicted transform
78
    """
79
    
80
    model.eval()
81
    with torch.no_grad():
82
        transformPredicted = model(signal)
83
    transformPredicted = transformPredicted.cpu().numpy().reshape((-1,))
84
    valleys = findValleys(transformPredicted)
85
    return valleys, transformPredicted
86
87
def save_model(exp_dir, epoch, model, optimizer,best_dev_loss):
88
    """ save checkpoint of model 
89
    
90
    Arguments:
91
        exp_dir {string} -- Path to checkpoint
92
        epoch {int} -- epoch at which model is checkpointed
93
        model -- model state to be checkpointed
94
        optimizer {torch optimizer object} -- optimizer state of model to be checkpoint
95
        best_dev_loss {float} -- loss of model to be checkpointed
96
    """
97
    out = torch.save(
98
        {
99
            'epoch': epoch,
100
            'model': model.state_dict(),
101
            'optimizer': optimizer.state_dict(),
102
            'best_dev_loss': best_dev_loss,
103
            'exp_dir':exp_dir
104
        },
105
        f=exp_dir + '/best_model.pt'
106
    )
107
108
def dist_transform(signal, ann):
109
110
    """ Compute distance transform of Respiration signaal based on breath positions
111
    Arguments:
112
        signal{ndarray} -- The ECG signal  
113
        ann{ndarray} -- The ground truth R-Peaks
114
    Returns:
115
       ndarray -- transformed signal
116
    """
117
    
118
    length = len(signal)
119
    transform = []
120
121
    sample = 0
122
    if len(ann) == 0:
123
        return None
124
    if len(ann) ==1:
125
        for i in range(length):
126
            transform.append(abs(i-ann[sample]))
127
    else:
128
        for i in range(length):
129
130
            if sample+1 == len(ann):
131
                for j in range(i,length):
132
133
                    transform.append(abs(j - nextAnn))
134
                break
135
            prevAnn = ann[sample]
136
            nextAnn = ann[sample+1]
137
            middle = int((prevAnn + nextAnn )/2) 
138
            if i < middle:
139
                transform.append(abs(i - prevAnn))
140
            elif i>= middle:
141
                transform.append(abs(i- nextAnn))
142
            if i == nextAnn:
143
                sample+=1
144
145
    transform = np.array(transform)
146
    minmaxScaler = MinMaxScaler()
147
    transform = minmaxScaler.fit_transform(transform.reshape((-1,1)))
148
    return transform
149
150
151
def getWindow(signal,ann, windows = 10, freq  = 125, overlap = 0.5):
152
    """Generate ECG and Respiration signals with annotations of specified window length
153
    
154
    Arguments:
155
        signal {2-D array} -- array containing ecg at index 0 and resp at index 1
156
        ann {list} -- annotations within specified window
157
    
158
    Keyword Arguments:
159
        windows {int} -- size of window in seconds (default: {5})
160
        freq {int} -- sampling rate in Hz (default: {125})
161
        overlap {float} -- percentage of overlap between windows (default: {0.5})
162
    
163
    Yields:
164
        tuple -- signals and correspoinding annotations
165
    """
166
    
167
    for start in range(0,len(signal),int((1-overlap)*freq*windows)):
168
        yield (signal[start: start + windows*freq, :],[x-start for x in ann if x >= start and x < start+windows*freq])