class_deepdataset.py

import os
import torch
import random
import warnings
import numpy as np
from torch.utils.data import Dataset

'''
There are two versions of segment selection for feature-vector extraction:

- DeepDataset takes the first 500 valid segments per patient.
- DeepDatasetV2 collects all valid segments per patient and randomly samples 500 of them.

See the usage sketch at the bottom of this file.
'''


class DeepDataset(Dataset):
    def __init__(self, data_path, patient_ids, fs, lead):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
        self.fs = fs
        self.lead = lead
        self.vectors = []
        self.labels = []
        self.m = 8  # Number of QRS complexes concatenated into each feature vector
        self.segment_length = 10  # In seconds
        self.num_segments = int(15 * 3600 / self.segment_length)  # Total segments in 15 hours
        self.max_segments_per_patient = 500

        for patient_id in self.patient_ids:
            valid_segments_count = 0
            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
            signal = np.load(signal_path)[self.lead, :15 * 3600 * self.fs]
            qrs_indices = np.load(qrs_path)

            for segment_idx in range(self.num_segments):
                if valid_segments_count >= self.max_segments_per_patient:
                    break  # Stop once we have enough valid segments

                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                # QRS annotations inside this segment, shifted to segment-local sample indices
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start

                if len(qrs_in_segment) > 0:  # A segment is valid if it contains at least one QRS complex
                    segment = signal[start:end]
                    vector_v = self.extract_feature_vector(segment, qrs_in_segment)
                    self.vectors.append(vector_v)
                    self.labels.append(self.id_mapped[patient_id])
                    valid_segments_count += 1

            if valid_segments_count < self.max_segments_per_patient:
                warnings.warn(f"Patient ID {patient_id} has fewer than "
                              f"{self.max_segments_per_patient} valid segments: "
                              f"{valid_segments_count} segments found.")

        self.vectors = np.array(self.vectors)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, index):
        vector_v = self.vectors[index]
        label = self.labels[index]
        # Add a channel dimension so each item has shape (1, m * window_length)
        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)

    def extract_feature_vector(self, segment, qrs_indices):
        qrs_complexes = []
        half_window = int(0.125 * self.fs / 2)  # Half of a 125 ms window centered on each QRS
        window_length = 2 * half_window

        for idx in qrs_indices:
            start_idx = max(0, idx - half_window)
            end_idx = min(len(segment), idx + half_window)
            qrs_complex = segment[start_idx:end_idx]
            if len(qrs_complex) < window_length:
                # Windows clipped at a segment boundary are zero-padded at the end
                padding = window_length - len(qrs_complex)
                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')

            qrs_complexes.append(qrs_complex)

        # Correlation against the average QRS template ranks how typical each complex is
        average_qrs = np.mean(qrs_complexes, axis=0) if qrs_complexes else np.zeros(window_length)
        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
        top_m_indices = np.argsort(correlations)[-self.m:]
        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)

        # If fewer than m complexes are available, repeat the best-matching one to reach full length
        while len(vector_v) < self.m * window_length:
            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
        return vector_v[:self.m * window_length]


class DeepDatasetV2(Dataset):
    def __init__(self, data_path, patient_ids, fs, lead):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
        self.fs = fs
        self.lead = lead
        self.vectors = []
        self.labels = []
        self.m = 8  # Number of QRS complexes concatenated into each feature vector
        self.segment_length = 10  # In seconds
        self.num_segments = int(15 * 3600 / self.segment_length)  # Total segments in 15 hours
        self.max_segments_per_patient = 500

        for patient_id in self.patient_ids:
            valid_segment_indices = []

            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
            signal = np.load(signal_path)[self.lead, :15 * 3600 * self.fs]
            qrs_indices = np.load(qrs_path)

            # First pass: collect the indices of all valid segments
            for segment_idx in range(self.num_segments):
                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start

                if len(qrs_in_segment) > 0:  # A segment is valid if it contains at least one QRS complex
                    valid_segment_indices.append(segment_idx)

            selected_indices = random.sample(valid_segment_indices,
                                             min(len(valid_segment_indices), self.max_segments_per_patient))
            if len(valid_segment_indices) < self.max_segments_per_patient:
                warnings.warn(f"Patient ID {patient_id} has fewer than "
                              f"{self.max_segments_per_patient} valid segments: "
                              f"{len(valid_segment_indices)} segments found.")

            # Second pass: extract V vectors only from the randomly selected segments
            for segment_idx in selected_indices:
                start = segment_idx * self.segment_length * self.fs
                end = start + self.segment_length * self.fs
                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start
                segment = signal[start:end]
                vector_v = self.extract_feature_vector(segment, qrs_in_segment)
                self.vectors.append(vector_v)
                self.labels.append(self.id_mapped[patient_id])

        self.vectors = np.array(self.vectors)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, index):
        vector_v = self.vectors[index]
        label = self.labels[index]
        # Add a channel dimension so each item has shape (1, m * window_length)
        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)

    def extract_feature_vector(self, segment, qrs_indices):
        qrs_complexes = []
        half_window = int(0.125 * self.fs / 2)  # Half of a 125 ms window centered on each QRS
        window_length = 2 * half_window

        for idx in qrs_indices:
            start_idx = max(0, idx - half_window)
            end_idx = min(len(segment), idx + half_window)
            qrs_complex = segment[start_idx:end_idx]
            if len(qrs_complex) < window_length:
                # Windows clipped at a segment boundary are zero-padded at the end
                padding = window_length - len(qrs_complex)
                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')

            qrs_complexes.append(qrs_complex)

        # Correlation against the average QRS template ranks how typical each complex is
        average_qrs = np.mean(qrs_complexes, axis=0) if qrs_complexes else np.zeros(window_length)
        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
        top_m_indices = np.argsort(correlations)[-self.m:]
        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)

        # If fewer than m complexes are available, repeat the best-matching one to reach full length
        while len(vector_v) < self.m * window_length:
            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
        return vector_v[:self.m * window_length]