Diff of /ecg_gan/dataset.py [000000] .. [6bf179]

Switch to unified view

a b/ecg_gan/dataset.py
1
import numpy as np
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
from torch.utils.data import Dataset, DataLoader
7
8
from .config import config
9
10
11
class ECGDataset(Dataset):
12
13
    def __init__(self, df):
14
        self.df = df
15
        self.data_columns = self.df.columns[:-2].tolist()
16
17
    def __getitem__(self, idx):
18
        signal = self.df.loc[idx, self.data_columns].astype('float32')
19
        signal = torch.FloatTensor([signal.values])                 
20
        target = torch.LongTensor(np.array(self.df.loc[idx, 'class']))
21
        return signal, target
22
23
    def __len__(self):
24
        return len(self.df)
25
26
def get_dataloader(label_name, batch_size):
27
    df = pd.read_csv(config.csv_path)
28
    df = df.loc[df['label'] == label_name]
29
    df.reset_index(drop=True, inplace=True)
30
    dataset = ECGDataset(df)
31
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0)
32
    return dataloader
33
  
34
if __name__ == '__main__':
35
    config = Config()
36
    dataloader = get_dataloader('Fusion of ventricular and normal', 96)
37