|
a |
|
b/dataloaders/dataset1d.py |
|
|
1 |
import json |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import wfdb |
|
|
5 |
from scipy.signal import find_peaks |
|
|
6 |
from sklearn.preprocessing import scale |
|
|
7 |
from torch.utils.data import DataLoader, Dataset |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class EcgDataset1D(Dataset): |
|
|
11 |
def __init__(self, ann_path, mapping_path): |
|
|
12 |
super().__init__() |
|
|
13 |
self.data = json.load(open(ann_path)) |
|
|
14 |
self.mapper = json.load(open(mapping_path)) |
|
|
15 |
|
|
|
16 |
def __getitem__(self, index): |
|
|
17 |
img = np.load(self.data[index]["path"]).astype("float32") |
|
|
18 |
img = img.reshape(1, img.shape[0]) |
|
|
19 |
|
|
|
20 |
return {"image": img, "class": self.mapper[self.data[index]["label"]]} |
|
|
21 |
|
|
|
22 |
def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True): |
|
|
23 |
data_loader = DataLoader( |
|
|
24 |
self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, |
|
|
25 |
) |
|
|
26 |
return data_loader |
|
|
27 |
|
|
|
28 |
def __len__(self): |
|
|
29 |
return len(self.data) |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
def callback_get_label(dataset, idx): |
|
|
33 |
return dataset[idx]["class"] |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
class EcgPipelineDataset1D(Dataset): |
|
|
37 |
def __init__(self, path, mode=128): |
|
|
38 |
super().__init__() |
|
|
39 |
record = wfdb.rdrecord(path) |
|
|
40 |
self.signal = None |
|
|
41 |
self.mode = mode |
|
|
42 |
for sig_name, signal in zip(record.sig_name, record.p_signal.T): |
|
|
43 |
if sig_name in ["MLII", "II"] and np.all(np.isfinite(signal)): |
|
|
44 |
self.signal = scale(signal).astype("float32") |
|
|
45 |
if self.signal is None: |
|
|
46 |
raise Exception("No MLII LEAD") |
|
|
47 |
|
|
|
48 |
self.peaks = find_peaks(self.signal, distance=180)[0] |
|
|
49 |
mask_left = (self.peaks - self.mode // 2) > 0 |
|
|
50 |
mask_right = (self.peaks + self.mode // 2) < len(self.signal) |
|
|
51 |
mask = mask_left & mask_right |
|
|
52 |
self.peaks = self.peaks[mask] |
|
|
53 |
|
|
|
54 |
def __getitem__(self, index): |
|
|
55 |
peak = self.peaks[index] |
|
|
56 |
left, right = peak - self.mode // 2, peak + self.mode // 2 |
|
|
57 |
|
|
|
58 |
img = self.signal[left:right] |
|
|
59 |
img = img.reshape(1, img.shape[0]) |
|
|
60 |
|
|
|
61 |
return {"image": img, "peak": peak} |
|
|
62 |
|
|
|
63 |
def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True): |
|
|
64 |
data_loader = DataLoader( |
|
|
65 |
self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, |
|
|
66 |
) |
|
|
67 |
return data_loader |
|
|
68 |
|
|
|
69 |
def __len__(self): |
|
|
70 |
return len(self.peaks) |