Switch to unified view

a b/ecg_classification/dataset.py
1
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from .config import Config
9
10
11
class ECGDataset(Dataset):
    """Map-style dataset that serves one DataFrame row per item.

    Every column except the last two is treated as a signal sample; the
    target is read from the ``class`` column.
    """

    def __init__(self, df):
        """
        Parameters:
            df: table whose leading columns hold the signal samples and
                which contains a ``class`` column.
        """
        self.df = df
        # All columns but the trailing two are signal samples.
        self.data_columns = self.df.columns[:-2].tolist()

    def __getitem__(self, idx):
        """
        Return the signal and target stored at row position ``idx``.

        Parameters:
            idx: positional row number.
        Returns:
            signal: float32 tensor of shape (1, len(self.data_columns)).
            target: 0-dim int64 tensor holding the row's ``class`` value.
        """
        # Samplers hand out positions in range(len(self)), so rows are
        # looked up by position; a label lookup only works when the frame
        # carries a fresh 0..n-1 index.
        row = self.df.iloc[idx]
        samples = row[self.data_columns].to_numpy(dtype=np.float32)
        # torch.tensor copies the data, so the result never aliases the
        # DataFrame; unsqueeze adds the leading axis of size 1.
        signal = torch.tensor(samples).unsqueeze(0)
        # int() accepts integer and integral-float labels alike.
        target = torch.tensor(int(row['class']), dtype=torch.long)
        return signal, target

    def __len__(self) -> int:
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)
25
      
26
      
27
def get_dataloader(phase: str, batch_size: int = 96) -> DataLoader:
    """
    Build a DataLoader over the training or validation split.

    The CSV at ``train_csv_path`` from the project config is read and split
    85/15, stratified on its ``label`` column. The split uses the config
    seed as ``random_state``, so it is reproducible for a given seed.

    Parameters:
        phase: ``'train'`` selects the training split; any other value
            selects the validation split.
        batch_size: data per iteration.
    Returns:
        data generator (DataLoader wrapping an ECGDataset of the split).
    """
    # Only the Config class is imported by this module; the lowercase name
    # `config` was never bound, so create the settings object here.
    # NOTE(review): assumes Config can be constructed without arguments
    # and exposes `train_csv_path` and `seed` -- confirm against config.py.
    config = Config()

    df = pd.read_csv(config.train_csv_path)
    train_df, val_df = train_test_split(
        df, test_size=0.15, random_state=config.seed, stratify=df['label']
    )
    # Give each split a fresh 0..n-1 index.
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    split_df = train_df if phase == 'train' else val_df
    dataset = ECGDataset(split_df)
    # NOTE(review): no `shuffle` is passed, so batches follow the row order
    # produced by train_test_split on every epoch -- confirm this is intended
    # for the training phase.
    return DataLoader(dataset=dataset, batch_size=batch_size, num_workers=4)
45
  
46
  
47
if __name__ == '__main__':
    # Smoke run: build one loader per phase, training first, then validation.
    train_dataloader, val_dataloader = (
        get_dataloader(phase=phase_name, batch_size=96)
        for phase_name in ('train', 'val')
    )