Diff of /ecgtoBR/create_dataset.py [000000] .. [c0487b]

Switch to unified view

a b/ecgtoBR/create_dataset.py
1
import numpy as np
2
import os
3
import pandas as pd
4
import wfdb as wf
5
import argparse
6
from glob import glob
7
from tqdm import tqdm
8
from scipy.signal import resample
9
from sklearn.preprocessing import StandardScaler, MinMaxScaler
10
from sklearn.model_selection import train_test_split
11
import wget
12
13
import torch
14
15
from utils import dist_transform,getWindow
16
17
def data_preprocess(args):
18
19
    """ Preprocess data and create train - validate split
20
    """
21
22
    dat_path = os.path.join(args.data_path,'*.dat')
23
    paths = glob(dat_path)
24
25
    paths= sorted([path[:-4] for path in paths if path[-5] != "n"] )
26
27
    fs = args.sampling_freq
28
    fs_upsample = args.upsample_freq
29
    WINDOWS = args.window_length 
30
31
    ecgSignals = []
32
    BRSignals = []
33
    BRAnn1 = []
34
    BRAnn2 = []
35
36
    for path in tqdm(paths):
37
        
38
        ann    = wf.rdann(path,'breath')
39
        samples = np.array(ann.sample)
40
        aux_note = np.array(ann.aux_note)
41
        ann1 = samples[(aux_note == "ann1")]
42
        ann2 = samples[(aux_note == "ann2")]
43
        record = wf.io.rdrecord(path)
44
           
45
        ecgSignals.append(record.p_signal[:,record.sig_name.index('II,')])
46
        BRSignals.append(record.p_signal[:,record.sig_name.index('RESP,')])
47
        BRAnn1.append(ann1)
48
        BRAnn2.append(ann2)
49
50
    ecgSignals = np.array(ecgSignals,ndmin = 2)
51
    BRSignals = np.array(BRSignals, ndmin = 2)
52
53
    signals = np.stack([ecgSignals,BRSignals], axis= -1 )
54
55
    inputECG = []
56
    groundTruth = []
57
58
    for i in tqdm(range(len(signals))):
59
                    
60
        generateSignals = getWindow(signals[i],BRAnn2[i],windows=WINDOWS)
61
62
        for sig, ann in generateSignals:
63
        
64
            ecg = sig[:,0]
65
            br = sig[:,1]
66
            
67
            if len(ecg) == 1 or len(ann) == 0:
68
                break
69
70
            resampled = resample(ecg, WINDOWS*fs_upsample)
71
            scaler = StandardScaler()
72
            resampled = scaler.fit_transform(resampled.reshape((-1,1)))
73
            transform = dist_transform(br,ann)
74
            
75
            if resampled.shape == (fs_upsample*WINDOWS,1) and transform.shape == (WINDOWS*fs,1):
76
                inputECG.append(resampled.reshape((1,-1)))
77
                groundTruth.append(transform.reshape((1,-1)))
78
79
    X_train,X_test,y_train,y_test = train_test_split(np.array(inputECG),np.array(groundTruth),test_size = 0.2, random_state = 42)
80
81
    X_train_toTensor = torch.Tensor(X_train).type(torch.float)
82
    X_test_toTensor = torch.Tensor(X_test).type(torch.float)
83
    y_train_toTensor = torch.Tensor(y_train).type(torch.float)
84
    y_test_toTensor = torch.Tensor(y_test).type(torch.float)
85
    
86
    if not(os.path.exists('data')):
87
        os.mkdir('data')
88
89
    torch.save(X_train_toTensor, "data/ecgtoBR_train_data.pt")
90
    torch.save(y_train_toTensor, "data/ecgtoBR_train_labels.pt")
91
    torch.save(X_test_toTensor, "data/ecgtoBR_test_data.pt")
92
    torch.save(y_test_toTensor, "data/ecgtoBR_test_labels.pt")