Diff of /data.py [000000] .. [5a2c8f]

Switch to unified view

a b/data.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Sep 17 11:09:33 2019
4
5
@author: anne marie delaney
6
         eoin brophy
7
8
Data Loading module for GAN training
9
------------------------------------
10
11
Creating the Training Set
12
13
Creating the pytorch dataset class for use with Data Loader to enable batch training of the GAN
14
"""
15
import torch
16
from torch.utils.data import Dataset
17
import pandas as pd
18
19
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
20
21
class ECGData(Dataset):
22
  #This is the class for the ECG Data that we need to load, transform and then use in the dataloader.
23
  def __init__(self,source_file,class_id, transform = None):
24
    self.source_file = source_file
25
    data = pd.read_csv(source_file, header = None)
26
    class_data = data[data[187]==class_id]
27
    self.data = class_data.drop(class_data.iloc[:,187],axis=1)
28
    self.transform = transform
29
    self.class_id = class_id
30
    
31
  def __len__(self):
32
    return self.data.shape[0]
33
    
34
  def __getitem__(self,idx):
35
    sample = self.data.iloc[idx]
36
    if self.transform:
37
        sample = self.transform(sample)
38
    return sample
39
40
"""Including the function that will transform the dataframe to a pytorch tensor"""
41
42
class PD_to_Tensor(object):
43
    def __call__(self,sample):
44
      return torch.tensor(sample.values).to(device)
45