Diff of /dataloaders/dataset1d.py [000000] .. [fbbdf8]

Switch to unified view

a b/dataloaders/dataset1d.py
1
import json
2
3
import numpy as np
4
import wfdb
5
from scipy.signal import find_peaks
6
from sklearn.preprocessing import scale
7
from torch.utils.data import DataLoader, Dataset
8
9
10
class EcgDataset1D(Dataset):
11
    def __init__(self, ann_path, mapping_path):
12
        super().__init__()
13
        self.data = json.load(open(ann_path))
14
        self.mapper = json.load(open(mapping_path))
15
16
    def __getitem__(self, index):
17
        img = np.load(self.data[index]["path"]).astype("float32")
18
        img = img.reshape(1, img.shape[0])
19
20
        return {"image": img, "class": self.mapper[self.data[index]["label"]]}
21
22
    def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):
23
        data_loader = DataLoader(
24
            self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
25
        )
26
        return data_loader
27
28
    def __len__(self):
29
        return len(self.data)
30
31
32
def callback_get_label(dataset, idx):
33
    return dataset[idx]["class"]
34
35
36
class EcgPipelineDataset1D(Dataset):
37
    def __init__(self, path, mode=128):
38
        super().__init__()
39
        record = wfdb.rdrecord(path)
40
        self.signal = None
41
        self.mode = mode
42
        for sig_name, signal in zip(record.sig_name, record.p_signal.T):
43
            if sig_name in ["MLII", "II"] and np.all(np.isfinite(signal)):
44
                self.signal = scale(signal).astype("float32")
45
        if self.signal is None:
46
            raise Exception("No MLII LEAD")
47
48
        self.peaks = find_peaks(self.signal, distance=180)[0]
49
        mask_left = (self.peaks - self.mode // 2) > 0
50
        mask_right = (self.peaks + self.mode // 2) < len(self.signal)
51
        mask = mask_left & mask_right
52
        self.peaks = self.peaks[mask]
53
54
    def __getitem__(self, index):
55
        peak = self.peaks[index]
56
        left, right = peak - self.mode // 2, peak + self.mode // 2
57
58
        img = self.signal[left:right]
59
        img = img.reshape(1, img.shape[0])
60
61
        return {"image": img, "peak": peak}
62
63
    def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):
64
        data_loader = DataLoader(
65
            self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
66
        )
67
        return data_loader
68
69
    def __len__(self):
70
        return len(self.peaks)