Diff of /class_deepdataset.py [000000] .. [6536f9]


import os
import torch
import random
import warnings
import numpy as np
from torch.utils.data import Dataset

'''
There are two versions of selecting the segments that feature vectors are created from:
- DeepDataset takes the first 500 valid segments of each patient for feature vector extraction
- DeepDatasetV2 stores all valid segments of each patient and randomly selects 500 of them for feature vector extraction
'''
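
# How a feature vector V is built (see extract_feature_vector below): a 125 ms
# window is cut around every QRS index in a 10 s segment, the windows are
# averaged into a template beat, and the m = 8 windows with the highest
# (unnormalised) correlation against that template are concatenated into a
# single vector of m * window_length samples.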

class DeepDataset(Dataset):
    def __init__(self, data_path, patient_ids, fs, lead):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
        self.fs = fs
        self.lead = lead
        self.vectors = []
        self.labels = []
        self.m = 8  # Number of QRS complexes concatenated into one feature vector
        self.segment_length = 10  # In seconds
        self.num_segments = int(15*3600 / self.segment_length)  # Total segments in 15 hours
        self.max_segments_per_patient = 500

        for patient_id in self.patient_ids:
            valid_segments_count = 0
            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
            signal = np.load(signal_path)[self.lead, :15*3600*self.fs]
            qrs_indices = np.load(qrs_path)

            for segment_idx in range(self.num_segments):
                if valid_segments_count >= self.max_segments_per_patient:
                    break  # Stop once we have enough valid segments

                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start

                if len(qrs_in_segment) > 0:  # A segment is valid if it contains at least one QRS complex
                    segment = signal[start:end]
                    vector_v = self.extract_feature_vector(segment, qrs_in_segment)
                    self.vectors.append(vector_v)
                    self.labels.append(self.id_mapped[patient_id])
                    valid_segments_count += 1

            if valid_segments_count < self.max_segments_per_patient:
                warnings.warn(f"Patient ID {patient_id} has fewer than {self.max_segments_per_patient} valid segments: {valid_segments_count} segments found.")

        self.vectors = np.array(self.vectors)
        self.labels = np.array(self.labels)


    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, index):
        vector_v = self.vectors[index]
        label = self.labels[index]
        # Returns a (1, m * window_length) float tensor and a scalar class label
        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)


    def extract_feature_vector(self, segment, qrs_indices):
        qrs_complexes = []
        half_window = int(0.125 * self.fs / 2)  # 125 ms window around each QRS index
        window_length = 2 * half_window

        for idx in qrs_indices:
            start_idx = max(0, idx - half_window)
            end_idx = min(len(segment), idx + half_window)
            qrs_complex = segment[start_idx:end_idx]
            if len(qrs_complex) < window_length:  # Zero-pad windows clipped at the segment edges
                padding = window_length - len(qrs_complex)
                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')

            qrs_complexes.append(qrs_complex)

        average_qrs = np.mean(qrs_complexes, axis=0) if qrs_complexes else np.zeros(window_length)
        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
        top_m_indices = np.argsort(correlations)[-self.m:]
        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)

        # With fewer than m complexes, repeat the best-matching one until the vector is full
        while len(vector_v) < self.m * window_length:
            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
        return vector_v[:self.m * window_length]


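# DeepDatasetV2 differs from DeepDataset only in how segments are chosen: it first
# indexes every valid segment of a patient and then draws up to 500 of them at
# random; __len__, __getitem__, and extract_feature_vector are identical.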
class DeepDatasetV2(Dataset):
    def __init__(self, data_path, patient_ids, fs, lead):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
        self.fs = fs
        self.lead = lead
        self.vectors = []
        self.labels = []
        self.m = 8  # Number of QRS complexes concatenated into one feature vector
        self.segment_length = 10  # In seconds
        self.num_segments = int(15*3600 / self.segment_length)  # Total segments in 15 hours
        self.max_segments_per_patient = 500

        for patient_id in self.patient_ids:
            valid_segment_indices = []

            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
            signal = np.load(signal_path)[self.lead, :15*3600*self.fs]
            qrs_indices = np.load(qrs_path)

            for segment_idx in range(self.num_segments):
                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start

                if len(qrs_in_segment) > 0:  # A segment is valid if it contains at least one QRS complex
                    valid_segment_indices.append(segment_idx)

            selected_indices = random.sample(valid_segment_indices, min(len(valid_segment_indices), self.max_segments_per_patient))
            if len(valid_segment_indices) < self.max_segments_per_patient:
                warnings.warn(f"Patient ID {patient_id} has fewer than {self.max_segments_per_patient} valid segments: {len(valid_segment_indices)} segments found.")

            for segment_idx in selected_indices:  # Extract V vectors only from the selected segments
                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start
                segment = signal[start:end]
                vector_v = self.extract_feature_vector(segment, qrs_in_segment)
                self.vectors.append(vector_v)
                self.labels.append(self.id_mapped[patient_id])

        self.vectors = np.array(self.vectors)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, index):
        vector_v = self.vectors[index]
        label = self.labels[index]
        # Returns a (1, m * window_length) float tensor and a scalar class label
        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)

    def extract_feature_vector(self, segment, qrs_indices):
        qrs_complexes = []
        half_window = int(0.125 * self.fs / 2)  # 125 ms window around each QRS index
        window_length = 2 * half_window

        for idx in qrs_indices:
            start_idx = max(0, idx - half_window)
            end_idx = min(len(segment), idx + half_window)
            qrs_complex = segment[start_idx:end_idx]
            if len(qrs_complex) < window_length:  # Zero-pad windows clipped at the segment edges
                padding = window_length - len(qrs_complex)
                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')

            qrs_complexes.append(qrs_complex)

        average_qrs = np.mean(qrs_complexes, axis=0) if qrs_complexes else np.zeros(window_length)
        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
        top_m_indices = np.argsort(correlations)[-self.m:]
        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)

        # With fewer than m complexes, repeat the best-matching one until the vector is full
        while len(vector_v) < self.m * window_length:
            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
        return vector_v[:self.m * window_length]
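

# --- Minimal usage sketch (not part of the original module) ---
# Assumes preprocessed recordings stored as <patient_id>_signal.npy with shape
# [n_leads, n_samples] and <patient_id>_qrs.npy holding QRS sample indices.
# The directory, patient IDs, and sampling rate below are hypothetical
# placeholders, not values taken from the code above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    data_path = "./preprocessed"    # Hypothetical directory
    patient_ids = ["p001", "p002"]  # Hypothetical patient IDs
    fs = 128                        # Hypothetical sampling rate in Hz
    lead = 0

    dataset = DeepDatasetV2(data_path, patient_ids, fs=fs, lead=lead)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    vectors, labels = next(iter(loader))
    # With fs=128: window_length = 16 samples, so each item is (1, 8*16) = (1, 128)
    print(vectors.shape, labels.shape)  # e.g. torch.Size([32, 1, 128]) torch.Size([32])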