Diff of /class_ecgdataset.py [000000] .. [6536f9]

Switch to unified view

a b/class_ecgdataset.py
1
import os
2
import torch
3
import numpy as np
4
from torch.utils.data import Dataset
5
6
#----------------------------------------------------
7
#----------storing-signals-implementation------------
8
#----------------------------------------------------
9
10
class ECGDataset(Dataset):
    """Dataset of fixed-length ECG windows that are cut once and kept in memory.

    For every patient, ``{patient_id}_signal.npy`` is loaded from ``data_path``,
    truncated to its first ``MAX_HOURS`` hours, and ``n_windows`` windows of
    ``n_seconds`` seconds are drawn at random (without replacement) from a grid
    of non-overlapping start positions. The label of a window is the index of
    its patient in ``patient_ids``.

    Signals are indexed as ``signal[lead, sample]``: the first axis is selected
    with ``leads`` and the second axis is sliced in samples.

    Args:
        data_path: Directory containing the ``*_signal.npy`` files.
        patient_ids: Sequence of patient identifiers used in the file names.
        fs: Sampling frequency in samples per second.
        n_windows: Number of windows drawn per patient.
        n_seconds: Length of each window in seconds.
        leads: Optional index applied to the first axis of every signal. It
            must keep the array two-dimensional (e.g. a list or a slice).

    Raises:
        ValueError: If a signal is too short to provide ``n_windows``
            non-overlapping windows, or the window length is not positive.
    """

    # Only the first MAX_HOURS hours of every recording are used.
    MAX_HOURS = 15

    def __init__(self, data_path, patient_ids, fs, n_windows, n_seconds, leads=None):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.fs = fs
        self.n_windows = n_windows
        self.n_seconds = n_seconds
        self.leads = leads
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}

        # Precompute and store segments
        window_len = self._window_len()
        segments = []
        labels = []
        for patient_id in self.patient_ids:
            signal = self.load_signal(patient_id)
            label = self.id_mapped[patient_id]
            for start_point in self.generate_starts(signal):
                segments.append(signal[:, start_point:start_point + window_len])
                labels.append(label)

        self.segments = np.array(segments)
        self.labels = np.array(labels)

    def _window_len(self):
        """Return the window length in samples as a plain ``int``."""
        # Cast so that a float fs / n_seconds still yields a valid slice bound.
        return int(round(self.n_seconds * self.fs))

    def load_signal(self, patient_id):
        """Load one patient's signal, truncated in time and reduced to ``leads``."""
        signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
        signal = np.load(signal_path)
        max_samples = int(round(self.MAX_HOURS * 3600 * self.fs))
        signal = signal[:, :max_samples]
        if self.leads is not None:
            signal = signal[self.leads, :]
        return signal

    def generate_starts(self, signal):
        """Randomly pick ``n_windows`` distinct window start indices for ``signal``.

        Candidates lie on a grid spaced one window length apart, so the chosen
        windows never overlap.
        """
        window_len = self._window_len()
        if window_len <= 0:
            raise ValueError(
                f"window length must be positive, got {window_len} samples"
            )
        max_start = signal.shape[1] - window_len
        # max_start is itself a valid start, so the grid must include it;
        # otherwise the last full window could never be sampled.
        all_starts = np.arange(0, max_start + 1, window_len)
        if self.n_windows > len(all_starts):
            raise ValueError(
                f"cannot draw {self.n_windows} non-overlapping windows of "
                f"{window_len} samples from a signal of {signal.shape[1]} "
                f"samples (only {len(all_starts)} available)"
            )
        return np.random.choice(all_starts, self.n_windows, replace=False)

    def __len__(self):
        """Return the total number of stored windows."""
        return len(self.segments)

    def __getitem__(self, index):
        """Return ``(segment, label)`` as a float tensor and a long tensor."""
        segment = self.segments[index]
        label = self.labels[index]
        return torch.tensor(segment, dtype=torch.float), torch.tensor(label, dtype=torch.long)
56
57
58
59
#----------------------------------------------------
60
#-------------on-the-go-implementation---------------
61
#----------------------------------------------------
62
63
class ECGDataset_on_the_fly(Dataset):
    """Dataset of fixed-length ECG windows that are sliced at access time.

    Window start positions are drawn once per patient at construction; the
    window itself is cut from the cached signal inside ``__getitem__``. Drawing
    the start positions needs each signal's length, so every patient's signal
    is loaded during construction and kept in ``signal_cache``.

    Signals are indexed as ``signal[lead, sample]``: the first axis is selected
    with ``leads`` and the second axis is sliced in samples. Item ``index``
    maps to patient ``index // n_windows`` and window ``index % n_windows``;
    the label is the index of the patient in ``patient_ids``.

    Args:
        data_path: Directory containing the ``{patient_id}_signal.npy`` files.
        patient_ids: Sequence of patient identifiers used in the file names.
        fs: Sampling frequency in samples per second.
        n_windows: Number of windows drawn per patient.
        n_seconds: Length of each window in seconds.
        leads: Optional index applied to the first axis of every signal. It
            must keep the array two-dimensional (e.g. a list or a slice).

    Raises:
        ValueError: If a signal is too short to provide ``n_windows``
            non-overlapping windows, or the window length is not positive.
    """

    # Only the first MAX_HOURS hours of every recording are used.
    MAX_HOURS = 15

    def __init__(self, data_path, patient_ids, fs, n_windows, n_seconds, leads=None):
        self.data_path = data_path
        self.patient_ids = patient_ids
        self.fs = fs
        self.n_windows = n_windows
        self.n_seconds = n_seconds
        self.leads = leads
        self.signal_cache = {}
        self.id_mapped = {pid: idx for idx, pid in enumerate(patient_ids)}
        self.segment_starts = {pid: self.generate_starts(pid) for pid in patient_ids}

    def _window_len(self):
        """Return the window length in samples as a plain ``int``."""
        # Cast so that a float fs / n_seconds still yields a valid slice bound.
        return int(round(self.n_seconds * self.fs))

    def load_signal(self, patient_id):
        """Return one patient's truncated, lead-selected signal, loading it once."""
        if patient_id not in self.signal_cache:
            signal_path = os.path.join(self.data_path, f"{patient_id}_signal.npy")
            signal = np.load(signal_path)
            max_samples = int(round(self.MAX_HOURS * 3600 * self.fs))
            signal = signal[:, :max_samples]
            if self.leads is not None:
                # The time axis is already truncated above.
                signal = signal[self.leads, :]
            self.signal_cache[patient_id] = signal
        return self.signal_cache[patient_id]

    def generate_starts(self, patient_id):
        """Randomly pick ``n_windows`` distinct window start indices for a patient.

        Candidates lie on a grid spaced one window length apart, so the chosen
        windows never overlap.
        """
        signal = self.load_signal(patient_id)
        window_len = self._window_len()
        if window_len <= 0:
            raise ValueError(
                f"window length must be positive, got {window_len} samples"
            )
        max_start = signal.shape[1] - window_len
        # Make a grid of points that are n_seconds apart. max_start is itself a
        # valid start, so it must be included; otherwise the last full window
        # could never be sampled.
        all_starts = np.arange(0, max_start + 1, window_len)
        if self.n_windows > len(all_starts):
            raise ValueError(
                f"cannot draw {self.n_windows} non-overlapping windows of "
                f"{window_len} samples from the signal of patient "
                f"{patient_id!r} ({signal.shape[1]} samples, only "
                f"{len(all_starts)} available)"
            )
        # ...and choose n_windows of them as our starting points.
        return np.random.choice(all_starts, self.n_windows, replace=False)

    def __len__(self):
        """Return the total number of windows across all patients."""
        return len(self.patient_ids) * self.n_windows

    def __getitem__(self, index):
        """Return ``(segment, label)`` as a float tensor and a long tensor."""
        patient_index = index // self.n_windows
        window_index = index % self.n_windows
        patient_id = self.patient_ids[patient_index]

        start_point = self.segment_starts[patient_id][window_index]
        end_point = start_point + self._window_len()

        signal = self.load_signal(patient_id)
        segment = signal[:, start_point:end_point]
        label = self.id_mapped[patient_id]

        return torch.tensor(segment, dtype=torch.float), torch.tensor(label, dtype=torch.long)
109