[6bf179]: / ecg_gan / dataset.py

Download this file

38 lines (27 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from .config import config
class ECGDataset(Dataset):
    """Map-style dataset yielding one ``(signal, target)`` pair per frame row.

    Every column except the last two is treated as signal samples; the
    target is read from the ``'class'`` column.
    """

    def __init__(self, df):
        """Keep a reference to the frame and record the signal columns.

        Args:
            df: pandas DataFrame whose last two columns are not signal
                samples and which contains a ``'class'`` column.
        """
        self.df = df
        self.data_columns = self.df.columns[:-2].tolist()

    def __getitem__(self, idx):
        """Return the sample stored at row position ``idx``.

        Args:
            idx: integer row position in ``range(len(self))``.

        Returns:
            A tuple ``(signal, target)``: ``signal`` is a float32 tensor of
            shape ``(1, len(self.data_columns))`` and ``target`` is a
            zero-dimensional int64 tensor holding the row's ``'class'``.

        Raises:
            IndexError: if ``idx`` is out of bounds.
        """
        # Positional lookup keeps indexing consistent with __len__ whatever
        # the frame's index labels are, and raises IndexError past the end so
        # plain iteration over the dataset terminates.
        row = self.df.iloc[idx]
        samples = row[self.data_columns].to_numpy(dtype=np.float32)
        # Build the tensor straight from the array (not a list of arrays),
        # then add the leading dimension of size 1.
        signal = torch.tensor(samples).unsqueeze(0)
        target = torch.tensor(int(row['class']), dtype=torch.long)
        return signal, target

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)
def get_dataloader(label_name, batch_size):
    """Build a DataLoader over the CSV rows carrying one particular label.

    The CSV is read from ``config.csv_path``; only rows whose ``'label'``
    column equals ``label_name`` are kept.

    Args:
        label_name: value compared for equality against the ``'label'`` column.
        batch_size: number of samples per batch.

    Returns:
        A ``DataLoader`` wrapping an ``ECGDataset`` of the matching rows,
        loading in the main process (``num_workers=0``).
    """
    df = pd.read_csv(config.csv_path)
    # Filter and renumber in one non-mutating expression: rows are relabelled
    # 0..n-1 so the integer indices requested by the loader line up with the
    # frame's index, without an in-place edit on a filtered frame.
    subset = df.loc[df['label'] == label_name].reset_index(drop=True)
    dataset = ECGDataset(subset)
    return DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0)
if __name__ == '__main__':
    # Smoke run: build a loader for a single label with batch size 96.
    # get_dataloader reads the `config` object imported from `.config`, so no
    # local configuration is created here (the previous `Config()` call
    # referenced a name this module never imports).
    # NOTE(review): the relative import at the top means this presumably has
    # to be launched as a module (e.g. `python -m ecg_gan.dataset`) - confirm.
    dataloader = get_dataloader('Fusion of ventricular and normal', 96)