--- a +++ b/ecg_gan/dataset.py @@ -0,0 +1,37 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader + +from .config import config + + +class ECGDataset(Dataset): + + def __init__(self, df): + self.df = df + self.data_columns = self.df.columns[:-2].tolist() + + def __getitem__(self, idx): + signal = self.df.loc[idx, self.data_columns].astype('float32') + signal = torch.FloatTensor([signal.values]) + target = torch.LongTensor(np.array(self.df.loc[idx, 'class'])) + return signal, target + + def __len__(self): + return len(self.df) + +def get_dataloader(label_name, batch_size): + df = pd.read_csv(config.csv_path) + df = df.loc[df['label'] == label_name] + df.reset_index(drop=True, inplace=True) + dataset = ECGDataset(df) + dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0) + return dataloader + +if __name__ == '__main__': + config = Config() + dataloader = get_dataloader('Fusion of ventricular and normal', 96) +