[6bf179]: / ecg_classification / dataset.py

Download this file

50 lines (37 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from .config import Config
class ECGDataset(Dataset):
    '''
    Map-style dataset that serves one (signal, target) pair per DataFrame row.

    Every column except the last two is treated as a signal sample; the
    target is read from the column named 'class'.

    The signal and target values are copied into NumPy arrays once, at
    construction time, so item access is a cheap positional array lookup
    instead of a per-row pandas slice. Consequences:
      * rows are addressed by position (0 .. len - 1), whatever the
        DataFrame's index labels are;
      * an out-of-range index raises IndexError, so plain iteration over
        the dataset terminates normally;
      * changes made to `df` after construction are not reflected.

    Parameters:
    df: DataFrame whose last two columns are metadata (one of them must be
        named 'class') and whose remaining columns are numeric samples.
    Raises:
    KeyError: if `df` has no 'class' column.
    '''

    def __init__(self, df):
        self.df = df
        self.data_columns = self.df.columns[:-2].tolist()
        # Snapshot the data once; pandas row access in __getitem__ is slow.
        self._signals = self.df[self.data_columns].to_numpy(dtype='float32')
        self._targets = self.df['class'].to_numpy(dtype='int64')

    def __getitem__(self, idx):
        '''
        Return the pair stored at position `idx`.

        Returns:
        signal: float32 tensor of shape (1, number of data columns).
        target: zero-dimensional int64 tensor holding the class id.
        '''
        # torch.tensor copies, so callers may modify the result in place
        # without corrupting the cached arrays.
        signal = torch.tensor(self._signals[idx]).unsqueeze(0)
        # Explicit dtype avoids the legacy LongTensor constructor, which can
        # interpret an integer-like argument as a size rather than a value.
        target = torch.tensor(self._targets[idx], dtype=torch.long)
        return signal, target

    def __len__(self):
        return len(self._targets)
def get_dataloader(phase: str, batch_size: int = 96) -> DataLoader:
    '''
    Build a DataLoader over the training or validation split.

    The CSV at `train_csv_path` (from Config) is read and split 85% / 15%
    into train / validation parts, stratified on the 'label' column and
    seeded with `seed` (from Config), so repeated calls yield the same
    partition.

    Parameters:
    phase: 'train' selects the training split; any other value selects
        the validation split.
    batch_size: data per iteration.
    Returns:
    data generator
    '''
    # The module imports the Config class only; the settings object has to
    # be created here. NOTE(review): assumes Config() takes no arguments.
    config = Config()
    df = pd.read_csv(config.train_csv_path)
    train_df, val_df = train_test_split(
        df, test_size=0.15, random_state=config.seed, stratify=df['label']
    )
    split_df = train_df if phase == 'train' else val_df
    # Renumber rows from 0 so dataset indices line up with the row labels.
    dataset = ECGDataset(split_df.reset_index(drop=True))
    # NOTE(review): shuffle is left at the DataLoader default (off), so
    # batches come in the same order every epoch - confirm this is intended
    # for the 'train' phase.
    return DataLoader(dataset=dataset, batch_size=batch_size, num_workers=4)
if __name__ == '__main__':
    # Smoke check: build the loader for each phase, training first and
    # validation second, using 96 samples per batch.
    train_dataloader, val_dataloader = (
        get_dataloader(phase=split, batch_size=96)
        for split in ('train', 'val')
    )