# Diff of /preprocess.py [000000] .. [6536f9]

# Switch to unified view

# a b/preprocess.py
1
import os
2
import re
3
import wfdb
4
import numpy as np
5
from wfdb import processing
6
import sys
7
8
from tqdm import tqdm
9
from scipy.signal import butter, iirnotch, filtfilt
10
from sklearn.decomposition import PCA
11
12
from ecgdetectors import Detectors #https://pypi.org/project/py-ecg-detectors/
13
14
15
def interpolate_nans(record):
    """
    Interpolate over NaN values in each lead of a given signal.

    The array is modified in place and is also returned.  Gaps are filled by
    linear interpolation between the neighbouring valid samples (np.interp);
    NaNs located before the first or after the last valid sample take the
    value of that nearest valid sample.

    Parameters:
    - record (numpy.ndarray): array representing the ECG signal with leads in
                                rows, shape (channels x samples).  A 1D array
                                is treated as a single lead.

    Returns:
    - numpy.ndarray: the same array with NaNs interpolated; leads with all
                     NaNs remain unchanged (a warning is printed for them).
    """
    # View the input as 2D so a single 1D lead goes through the same loop.
    # For an ndarray this is a view, so writes below land in `record` itself.
    leads = np.atleast_2d(record)

    # TODO: apply the interpolation in a vectorized manner instead of per lead.
    for i, lead_signal in enumerate(leads):
        nan_mask = np.isnan(lead_signal)

        if nan_mask.all():
            print(f"Warning: Lead {i} contains only NaNs.")
            continue

        if nan_mask.any():
            valid_idx = np.flatnonzero(~nan_mask)
            leads[i, :] = np.interp(
                np.arange(lead_signal.size),
                valid_idx,
                lead_signal[valid_idx],
            )
    return record
42
43
def filter_signal(signal, fs):
    """
    Apply a band-pass and a mains notch filter to an ECG signal.

    A 3rd-order Butterworth band-pass (0.5-40 Hz) is followed by a 50 Hz
    notch (Q = 3).  Both are applied forward and backward with filtfilt
    along the last (sample) axis.

    Parameters:
    signal (array): the ECG signal to filter with shape (channels x samples)
                    can be 1D (single lead) or 2D (multichannel)
    fs (int): sampling frequency of the signal in Hz; the 50 Hz notch needs
              fs > 100, otherwise filter design fails and None is returned

    Returns:
    array or None: The filtered signal, or None if the input is empty or an
                   error occurs (the error is printed, not raised).
    """
    # Any failure is reported and mapped to None: callers use None as the
    # "skip this record" signal, so nothing is re-raised here on purpose.
    try:
        signal = np.asarray(signal, dtype=float)

        # filtfilt cannot pad an empty signal, so bail out before filtering.
        if signal.size == 0:
            print("Warning: signal is empty, nothing to filter.")
            return None

        if np.isnan(signal).any():  # safety net in case interpolation left NaNs behind
            print("NaN values detected during filtering, replacing with zeros...")
            # nan_to_num also maps +/-inf to large finite values.
            signal = np.nan_to_num(signal)

        # axis=-1 is the sample axis for both (samples,) and (channels, samples).
        b, a = butter(3, (0.5, 40), btype='bandpass', fs=fs)
        signal = filtfilt(b, a, signal, axis=-1)
        bn, an = iirnotch(50, 3, fs=fs)
        signal = filtfilt(bn, an, signal, axis=-1)

        return signal
    except Exception as e:
        print(f"An error occurred during filtering: {e}")
        return None
72
73
# A regular expression is used here because info.txt is loosely formatted.
def extract_patient_ids(info_path):
    """
    Extract patient IDs from an info.txt file.

    A patient ID is a run of exactly four digits at the very start of a line,
    followed by a word boundary; lines that start with anything else
    (headers, indented text, longer numbers) are ignored.

    Parameters:
    - info_path (str): path to the info.txt file

    Returns:
    - list of IDs (str) in file order; duplicates are kept
    """
    id_pattern = re.compile(r'^\d{4}\b')

    # Only the leading digits of each line matter, so undecodable bytes in
    # the free-text part are replaced instead of aborting the whole run.
    with open(info_path, 'r', encoding='utf-8', errors='replace') as file:
        matches = (id_pattern.match(line) for line in file)
        return [match.group() for match in matches if match]
92
93
94
def preprocess_and_save(data_path, save_path, skip_seconds=300):
    """
    Preprocess ECG signal data and save the filtered signals and QRS indices.

    steps:
    - iteration over the ECG records listed in 'info.txt'
    - the first `skip_seconds` of every record are discarded
    - NaN interpolation and filtering of the signals
    - QRS detection on the first lead in a two-step process:
        - an initial detection using the Pan-Tompkins algorithm
        - refinement with WFDB's `correct_peaks`

    parameters:
    - data_path (str): path to database (must contain 'info.txt' and the records)
    - save_path (str): where to store preprocessed signals (created if missing)
    - skip_seconds (int or float): length of the leading segment to drop from
                                   each record, in seconds (default 300 = 5 min)

    outputs written to save_path:
    - patient_ids.txt: one patient ID per line
    - <id>_signal.npy: filtered signals, shape (number of channels x number of samples)
    - <id>_qrs.npy: corrected QRS sample indices, relative to the trimmed signal

    Records for which filtering fails are reported and skipped.
    """
    os.makedirs(save_path, exist_ok=True)

    info_path = os.path.join(data_path, 'info.txt')
    patient_ids = extract_patient_ids(info_path)

    patient_ids_path = os.path.join(save_path, 'patient_ids.txt')
    with open(patient_ids_path, 'w') as f:
        f.writelines(f"{patient_id}\n" for patient_id in patient_ids)

    for patient_id in tqdm(patient_ids, desc='Processing Patients'):
        # Record files are stored with an extra leading zero in front of the ID.
        record_path = os.path.join(data_path, f"0{patient_id}")
        save_signals_file = os.path.join(save_path, f"{patient_id}_signal.npy")
        save_qrs_file = os.path.join(save_path, f"{patient_id}_qrs.npy")

        record = wfdb.rdrecord(record_path)
        fs = record.fs

        # Slice bounds must be integers even if fs or skip_seconds is a float.
        start = max(0, int(round(fs * skip_seconds)))
        # Transpose to get the shape (number of channels x number of samples).
        signals = record.p_signal[start:, :].T
        interpolated_signals = interpolate_nans(signals)

        filtered_signals = filter_signal(interpolated_signals, fs)
        if filtered_signals is None:
            print(f"Filtering failed for patient ID {patient_id}. Skipping QRS detection.")
            continue

        # TODO: fs is constant, so the filter coefficients do not have to be
        # recomputed on every call; consider moving the filter design here.

        # Finding approximate peaks first and correcting them with wfdb is much
        # faster than running the whole search with wfdb.
        detectors = Detectors(fs)
        qrs_inds = np.asarray(
            detectors.pan_tompkins_detector(interpolated_signals[0]), dtype=int
        )

        if qrs_inds.size:
            corrected_peak_inds = processing.correct_peaks(
                filtered_signals[0],
                peak_inds=qrs_inds,
                search_radius=int(fs * 0.2),
                smooth_window_size=int(fs * 0.1),
            )
        else:
            # NOTE(review): correct_peaks presumably cannot handle an empty
            # index list - confirm. An empty result is saved so that every
            # processed record still gets both output files.
            corrected_peak_inds = qrs_inds

        np.save(save_signals_file, filtered_signals)
        np.save(save_qrs_file, corrected_peak_inds)
150
151
if __name__ == "__main__":
    # Command-line entry point: exactly two positional arguments are expected.
    if len(sys.argv) != 3:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Usage: python preprocess.py <data_path> <save_path>")

    data_path, save_path = sys.argv[1:]

    preprocess_and_save(data_path, save_path)