Diff of /preprocess.py [000000] .. [6536f9]

Switch to side-by-side view

--- a
+++ b/preprocess.py
@@ -0,0 +1,159 @@
+import os
+import re
+import wfdb
+import numpy as np
+from wfdb import processing
+import sys
+
+from tqdm import tqdm
+from scipy.signal import butter, iirnotch, filtfilt
+from sklearn.decomposition import PCA
+
+from ecgdetectors import Detectors #https://pypi.org/project/py-ecg-detectors/
+
+
+def interpolate_nans(record):
+    '''
+    interpolate over NaN values in each lead of a given signal
+
+    Parameters:
+    - record (numpy.ndarray): 2D numpy array representing the ECG signal with leads in rows
+                                shape (channels x samples)
+
+    Returns:
+    -numpy.ndarray: the signal with NaNs interpolated, leads with all NaNs remain unchanged.
+    '''
+
+    for i in range(record.shape[0]):  ###for future: instead of looping through each lead, applying interpolation in a vectorized manner
+        lead_signal = record[i, :]
+
+        if np.isnan(lead_signal).all():
+            print(f"Warning: Lead {i} contains only NaNs.")
+            continue
+
+        if np.isnan(lead_signal).any():
+            valid_mask = ~np.isnan(lead_signal)
+            record[i, :] = np.interp(
+                np.arange(len(lead_signal)),
+                valid_mask.nonzero()[0],
+                lead_signal[valid_mask]
+            )
+    return record
+
+def filter_signal(signal, fs):
+    """
+    apply a high-pass and notch filter to an ECG signal
+
+    Parameters:
+    signal (array): the ECG signal to filter with shape (channels x samples)
+                    can be 1D (single lead) or 2D (multichannel)
+    fs (int): sampling frequency of the signal
+
+    Returns:
+    array or None: The filtered signal or None if an error occurs.
+    """
+    try:
+        if np.isnan(signal).any():  # if the interpolation didn't work this will catch it
+            print("NaN values detected during filtering, interpolating...")
+            signal = np.nan_to_num(signal)  # Replace NaNs with zeros
+
+        [b, a] = butter(3, (0.5, 40), btype='bandpass', fs=fs)
+        signal = filtfilt(b, a, signal, axis=1)
+        [bn, an] = iirnotch(50, 3, fs=fs)
+        signal = filtfilt(bn, an, signal, axis=1)
+
+        if signal.size == 0:
+            print("Warning: Filtered signal is empty.")
+
+        return signal
+    except Exception as e:
+        print(f"An error occurred during filtering: {e}")
+        return None
+
+#I'm using regular expression for this one coz info.txt is messy
+def extract_patient_ids(info_path):
+    '''
+    extracts patient IDs from info.txt file
+    
+    returns:
+    - list of IDs
+    '''
+    patient_ids = []
+    id_pattern = re.compile(r'^\d{4}\b')
+
+    with open(info_path, 'r') as file:
+        for line in file:
+            match = id_pattern.match(line)
+            if match:
+                patient_id = match.group()
+                patient_ids.append(patient_id)
+
+    return patient_ids
+
+
+def preprocess_and_save(data_path, save_path):
+    '''
+    Preprocesses ECG signal data and saves the filtered signals and QRS complex indices.
+
+    steps:
+    - itaration over ECG records specified in 'info.txt'
+    - NaN interpolation and filtering of the signals
+    - QRS detection in a two-step process: 
+        - an initial detection using the Pan-Tompkins algorithm
+        - refinement with WFDB's `correct_peaks`
+
+    parameters:
+    - data_path (str): path to database
+    - save_path (str): where to store preprocessed signals
+
+    signals are saved in separate files with shape (number of channels x number of samples)
+    '''
+
+    os.makedirs(save_path, exist_ok=True)
+
+    info_path = os.path.join(data_path, 'info.txt')
+    patient_ids = extract_patient_ids(info_path)
+
+    patient_ids_path = os.path.join(save_path, 'patient_ids.txt')
+    with open(patient_ids_path, 'w') as f:
+        for id in patient_ids:
+            f.write(f"{id}\n")
+
+    for patient_id in tqdm(patient_ids, desc='Processing Patients'):
+        record_path = os.path.join(data_path, f"0{patient_id}") #I'm adding a zero before id coz thats how the files are saved
+        save_signals_file = os.path.join(save_path, f"{patient_id}_signal.npy")
+        save_qrs_file = os.path.join(save_path, f"{patient_id}_qrs.npy")
+
+        record = wfdb.rdrecord(record_path)
+        fs = record.fs
+        signals = record.p_signal[(fs*60*5):, :].T #getting the shape (number of channels x number of samples)
+        interpolated_signals = interpolate_nans(signals)
+
+        filtered_signals = filter_signal(interpolated_signals, record.fs)
+        if filtered_signals is None:
+            print(f"Filtering failed for patient ID {patient_id}. Skipping QRS detection.")
+            continue
+
+        ### for future: fs is constant so filter coefficients don't have to be recalculated for every call
+        ### move filtering from seperate function to preprocess_and_save
+
+        #its way faster to find approximate peaks and correct them with wfdb rather than doing the whole search with wfdb
+        detectors = Detectors(record.fs) 
+        qrs_inds = detectors.pan_tompkins_detector(interpolated_signals[0])
+        corrected_peak_inds = processing.correct_peaks(filtered_signals[0],
+                                                    peak_inds=qrs_inds,
+                                                    search_radius=int(record.fs*0.2),
+                                                    smooth_window_size=int(record.fs*0.1))
+        
+        np.save(save_signals_file, filtered_signals)
+        np.save(save_qrs_file, corrected_peak_inds)
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print("Usage: python preprocess.py <data_path> <save_path>")
+        sys.exit(1)
+
+    data_path = sys.argv[1]
+    save_path = sys.argv[2]
+
+    preprocess_and_save(data_path, save_path)
\ No newline at end of file