--- a +++ b/dataset.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +''' +@time: 2019/9/8 19:47 + +@ author: javis +''' +import pywt, os, copy +import torch +import numpy as np +import pandas as pd +from config import config +from torch.utils.data import Dataset +from sklearn.preprocessing import scale +from scipy import signal + + +def resample(sig, target_point_num=None): + ''' + 对原始信号进行重采样 + :param sig: 原始信号 + :param target_point_num:目标型号点数 + :return: 重采样的信号 + ''' + sig = signal.resample(sig, target_point_num) if target_point_num else sig + return sig + +def scaling(X, sigma=0.1): + scalingFactor = np.random.normal(loc=1.0, scale=sigma, size=(1, X.shape[1])) + myNoise = np.matmul(np.ones((X.shape[0], 1)), scalingFactor) + return X * myNoise + +def verflip(sig): + ''' + 信号竖直翻转 + :param sig: + :return: + ''' + return sig[::-1, :] + +def shift(sig, interval=20): + ''' + 上下平移 + :param sig: + :return: + ''' + for col in range(sig.shape[1]): + offset = np.random.choice(range(-interval, interval)) + sig[:, col] += offset + return sig + + +def transform(sig, train=False): + # 前置不可或缺的步骤 + sig = resample(sig, config.target_point_num) + # # 数据增强 + if train: + if np.random.randn() > 0.5: sig = scaling(sig) + if np.random.randn() > 0.5: sig = verflip(sig) + if np.random.randn() > 0.5: sig = shift(sig) + # 后置不可或缺的步骤 + sig = sig.transpose() + sig = torch.tensor(sig.copy(), dtype=torch.float) + return sig + + +class ECGDataset(Dataset): + """ + A generic data loader where the samples are arranged in this way: + dd = {'train': train, 'val': val, "idx2name": idx2name, 'file2idx': file2idx} + """ + + def __init__(self, data_path, train=True): + super(ECGDataset, self).__init__() + dd = torch.load(config.train_data) + self.train = train + self.data = dd['train'] if train else dd['val'] + self.idx2name = dd['idx2name'] + self.file2idx = dd['file2idx'] + self.wc = 1. / np.log(dd['wc']) + + def __getitem__(self, index): + fid = self.data[index] + file_path = os.path.join(config.train_dir, fid) + df = pd.read_csv(file_path, sep=' ').values + x = transform(df, self.train) + target = np.zeros(config.num_classes) + target[self.file2idx[fid]] = 1 + target = torch.tensor(target, dtype=torch.float32) + return x, target + + def __len__(self): + return len(self.data) + + +if __name__ == '__main__': + d = ECGDataset(config.train_data) + print(d[0])