--- a +++ b/ecg_classification/dataset.py @@ -0,0 +1,49 @@ +import numpy as np + +import torch +from torch.utils.data import Dataset, DataLoader + +from sklearn.model_selection import train_test_split + +from .config import Config + + +class ECGDataset(Dataset): + + def __init__(self, df): + self.df = df + self.data_columns = self.df.columns[:-2].tolist() + + def __getitem__(self, idx): + signal = self.df.loc[idx, self.data_columns].astype('float32') + signal = torch.FloatTensor([signal.values]) + target = torch.LongTensor(np.array(self.df.loc[idx, 'class'])) + return signal, target + + def __len__(self): + return len(self.df) + + +def get_dataloader(phase: str, batch_size: int = 96) -> DataLoader: + ''' + Dataset and DataLoader. + Parameters: + pahse: training or validation phase. + batch_size: data per iteration. + Returns: + data generator + ''' + df = pd.read_csv(config.train_csv_path) + train_df, val_df = train_test_split( + df, test_size=0.15, random_state=config.seed, stratify=df['label'] + ) + train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True) + df = train_df if phase == 'train' else val_df + dataset = ECGDataset(df) + dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=4) + return dataloader + + +if __name__ == '__main__': + train_dataloader = get_dataloader(phase='train', batch_size=96) + val_dataloader = get_dataloader(phase='val', batch_size=96)