Diff of /ecg_gan/dataset.py [000000] .. [6bf179]

Switch to side-by-side view

--- a
+++ b/ecg_gan/dataset.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+from .config import config
+
+
+class ECGDataset(Dataset):
+
+    def __init__(self, df):
+        self.df = df
+        self.data_columns = self.df.columns[:-2].tolist()
+
+    def __getitem__(self, idx):
+        signal = self.df.loc[idx, self.data_columns].astype('float32')
+        signal = torch.FloatTensor([signal.values])                 
+        target = torch.LongTensor(np.array(self.df.loc[idx, 'class']))
+        return signal, target
+
+    def __len__(self):
+        return len(self.df)
+
+def get_dataloader(label_name, batch_size):
+    df = pd.read_csv(config.csv_path)
+    df = df.loc[df['label'] == label_name]
+    df.reset_index(drop=True, inplace=True)
+    dataset = ECGDataset(df)
+    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0)
+    return dataloader
+  
+if __name__ == '__main__':
+    config = Config()
+    dataloader = get_dataloader('Fusion of ventricular and normal', 96)
+