--- /dev/null
+++ b/class_deepdataset.py
@@ -0,0 +1,170 @@
+import os
+import torch
+import random
+import warnings
+import numpy as np
+from torch.utils.data import Dataset
+
+'''
+Two strategies are implemented for selecting the segments that feature
+vectors are extracted from:
+- DeepDataset takes the first 500 valid segments per patient.
+- DeepDatasetV2 collects all valid segments per patient and randomly
+  samples 500 of them.
+A usage sketch is given in the __main__ block at the end of this file.
+'''
+
+class DeepDataset(Dataset):
+    def __init__(self, data_path, patient_ids, fs, lead):
+        self.data_path = data_path
+        self.patient_ids = patient_ids
+        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
+        self.fs = fs
+        self.lead = lead
+        self.vectors = []
+        self.labels = []
+        self.m = 8  # number of QRS complexes concatenated into each feature vector
+        self.segment_length = 10  # in seconds
+        self.num_segments = int(15*3600 / self.segment_length)  # total segments in 15 hours
+        self.max_segments_per_patient = 500
+
+        for patient_id in self.patient_ids:
+            valid_segments_count = 0
+            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
+            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
+            # Load the first 15 hours of the selected lead.
+            signal = np.load(signal_path)[self.lead, :15*3600*self.fs]
+            qrs_indices = np.load(qrs_path)
+
+            for segment_idx in range(self.num_segments):
+                if valid_segments_count >= self.max_segments_per_patient:
+                    break  # Stop if we have enough valid segments
+
+                start = segment_idx * self.segment_length * self.fs
+                end = start + self.segment_length * self.fs
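+                # QRS sample indices falling in this segment, re-based to the segment start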
+                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start
+
+                if len(qrs_in_segment) > 0:  # a segment is valid if it contains at least one QRS complex
+                    segment = signal[start:end]
+                    vector_v = self.extract_feature_vector(segment, qrs_in_segment)
+                    self.vectors.append(vector_v)
+                    self.labels.append(self.id_mapped[patient_id])
+                    valid_segments_count += 1
+
+            if valid_segments_count < self.max_segments_per_patient:
+                warnings.warn(f"Patient ID {patient_id} has less than 500 valid segments: {valid_segments_count} segments found.")
+        
+        self.vectors = np.array(self.vectors)
+        self.labels = np.array(self.labels)
+
+
+    def __len__(self):
+        return len(self.vectors)
+
+    def __getitem__(self, index):
+        vector_v = self.vectors[index]
+        label = self.labels[index]
+        # unsqueeze(0) adds a channel dimension, giving shape (1, m * window_length)
+        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)
+
+
+    def extract_feature_vector(self, segment, qrs_indices):
+        # Cut a fixed-length window (~125 ms) around each QRS peak.
+        qrs_complexes = []
+        half_window = int(0.125 * self.fs / 2)
+        window_length = 2 * half_window
+
+        if len(qrs_indices) == 0:  # guard: callers only pass segments with QRS, but fail safe
+            return np.zeros(self.m * window_length)
+
+        for idx in qrs_indices:
+            start_idx = max(0, idx - half_window)
+            end_idx = min(len(segment), idx + half_window)
+            qrs_complex = segment[start_idx:end_idx]
+            if len(qrs_complex) < window_length:  # zero-pad windows clipped at the segment border
+                padding = window_length - len(qrs_complex)
+                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')
+            qrs_complexes.append(qrs_complex)
+
+        # Average all complexes into a template, then keep the m complexes that
+        # correlate best with it and concatenate them into a single vector.
+        average_qrs = np.mean(qrs_complexes, axis=0)
+        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
+        top_m_indices = np.argsort(correlations)[-self.m:]
+        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)
+
+        # If fewer than m complexes were found, repeat the best-correlated one
+        # until the vector reaches its fixed length of m * window_length samples.
+        while len(vector_v) < self.m * window_length:
+            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
+        return vector_v[:self.m * window_length]
+
+
+class DeepDatasetV2(Dataset):
+    def __init__(self, data_path, patient_ids, fs, lead):
+        self.data_path = data_path
+        self.patient_ids = patient_ids
+        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
+        self.fs = fs
+        self.lead = lead
+        self.vectors = []
+        self.labels = []
+        self.m = 8  # number of QRS complexes concatenated into each feature vector
+        self.segment_length = 10  # in seconds
+        self.num_segments = int(15*3600 / self.segment_length)  # total segments in 15 hours
+        self.max_segments_per_patient = 500
+
+        for patient_id in self.patient_ids:
+            valid_segment_indices = []
+
+            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
+            qrs_path = os.path.join(self.data_path, f"{patient_id}_qrs.npy")
+            # Load the first 15 hours of the selected lead.
+            signal = np.load(signal_path)[self.lead, :15*3600*self.fs]
+            qrs_indices = np.load(qrs_path)
+
+            for segment_idx in range(self.num_segments):
+                start = segment_idx * self.segment_length * self.fs
+                end = start + self.segment_length * self.fs
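+                # QRS sample indices falling in this segment, re-based to the segment start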
+                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start
+
+                if len(qrs_in_segment) > 0:  # a segment is valid if it contains at least one QRS complex
+                    valid_segment_indices.append(segment_idx)
+
+            selected_indices = random.sample(valid_segment_indices, min(len(valid_segment_indices), self.max_segments_per_patient))
+            if len(valid_segment_indices) < self.max_segments_per_patient:
+                warnings.warn(f"Patient ID {patient_id} has fewer than {self.max_segments_per_patient} valid segments: {len(valid_segment_indices)} found.")
+
+            for segment_idx in selected_indices:  # extract feature vectors only from the selected segments
+                start = segment_idx * self.segment_length * self.fs
+                end = start + self.segment_length * self.fs
+                qrs_in_segment = qrs_indices[(qrs_indices >= start) & (qrs_indices < end)] - start
+                segment = signal[start:end]
+                vector_v = self.extract_feature_vector(segment, qrs_in_segment)
+                self.vectors.append(vector_v)
+                self.labels.append(self.id_mapped[patient_id])
+        
+        self.vectors = np.array(self.vectors)
+        self.labels = np.array(self.labels)
+
+    def __len__(self):
+        return len(self.vectors)
+
+    def __getitem__(self, index):
+        vector_v = self.vectors[index]
+        label = self.labels[index]
+        # unsqueeze(0) adds a channel dimension, giving shape (1, m * window_length)
+        return torch.tensor(vector_v, dtype=torch.float).unsqueeze(0), torch.tensor(label, dtype=torch.long)
+
+    def extract_feature_vector(self, segment, qrs_indices):
+        # Cut a fixed-length window (~125 ms) around each QRS peak.
+        qrs_complexes = []
+        half_window = int(0.125 * self.fs / 2)
+        window_length = 2 * half_window
+
+        if len(qrs_indices) == 0:  # guard: callers only pass segments with QRS, but fail safe
+            return np.zeros(self.m * window_length)
+
+        for idx in qrs_indices:
+            start_idx = max(0, idx - half_window)
+            end_idx = min(len(segment), idx + half_window)
+            qrs_complex = segment[start_idx:end_idx]
+            if len(qrs_complex) < window_length:  # zero-pad windows clipped at the segment border
+                padding = window_length - len(qrs_complex)
+                qrs_complex = np.pad(qrs_complex, (0, padding), 'constant')
+            qrs_complexes.append(qrs_complex)
+
+        # Average all complexes into a template, then keep the m complexes that
+        # correlate best with it and concatenate them into a single vector.
+        average_qrs = np.mean(qrs_complexes, axis=0)
+        correlations = [np.correlate(qrs_complex, average_qrs, 'valid')[0] for qrs_complex in qrs_complexes]
+        top_m_indices = np.argsort(correlations)[-self.m:]
+        vector_v = np.concatenate([qrs_complexes[idx] for idx in top_m_indices], axis=0)
+
+        # If fewer than m complexes were found, repeat the best-correlated one
+        # until the vector reaches its fixed length of m * window_length samples.
+        while len(vector_v) < self.m * window_length:
+            vector_v = np.concatenate([vector_v, qrs_complexes[top_m_indices[-1]]])
+        return vector_v[:self.m * window_length]
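+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the training pipeline. All values
+    # below (directory, patient IDs, sampling rate, lead) are placeholder
+    # assumptions; point them at recordings stored as <id>_signal.npy and
+    # <id>_qrs.npy to try the dataset.
+    from torch.utils.data import DataLoader
+
+    dataset = DeepDatasetV2(
+        data_path="data/",            # placeholder directory
+        patient_ids=["p01", "p02"],   # placeholder patient IDs
+        fs=128,                       # placeholder sampling rate in Hz
+        lead=0,                       # index of the lead to use
+    )
+    loader = DataLoader(dataset, batch_size=4, shuffle=True)
+    vectors, labels = next(iter(loader))
+    print(vectors.shape, labels.shape)  # (4, 1, m * window_length) and (4,)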