|
a |
|
b/dataset.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
''' |
|
|
3 |
@time: 2019/9/8 19:47 |
|
|
4 |
|
|
|
5 |
@ author: javis |
|
|
6 |
''' |
|
|
7 |
import pywt, os, copy |
|
|
8 |
import torch |
|
|
9 |
import numpy as np |
|
|
10 |
import pandas as pd |
|
|
11 |
from config import config |
|
|
12 |
from torch.utils.data import Dataset |
|
|
13 |
from sklearn.preprocessing import scale |
|
|
14 |
from scipy import signal |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
def resample(sig, target_point_num=None): |
|
|
18 |
''' |
|
|
19 |
对原始信号进行重采样 |
|
|
20 |
:param sig: 原始信号 |
|
|
21 |
:param target_point_num:目标型号点数 |
|
|
22 |
:return: 重采样的信号 |
|
|
23 |
''' |
|
|
24 |
sig = signal.resample(sig, target_point_num) if target_point_num else sig |
|
|
25 |
return sig |
|
|
26 |
|
|
|
27 |
def scaling(X, sigma=0.1): |
|
|
28 |
scalingFactor = np.random.normal(loc=1.0, scale=sigma, size=(1, X.shape[1])) |
|
|
29 |
myNoise = np.matmul(np.ones((X.shape[0], 1)), scalingFactor) |
|
|
30 |
return X * myNoise |
|
|
31 |
|
|
|
32 |
def verflip(sig): |
|
|
33 |
''' |
|
|
34 |
信号竖直翻转 |
|
|
35 |
:param sig: |
|
|
36 |
:return: |
|
|
37 |
''' |
|
|
38 |
return sig[::-1, :] |
|
|
39 |
|
|
|
40 |
def shift(sig, interval=20): |
|
|
41 |
''' |
|
|
42 |
上下平移 |
|
|
43 |
:param sig: |
|
|
44 |
:return: |
|
|
45 |
''' |
|
|
46 |
for col in range(sig.shape[1]): |
|
|
47 |
offset = np.random.choice(range(-interval, interval)) |
|
|
48 |
sig[:, col] += offset |
|
|
49 |
return sig |
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def transform(sig, train=False): |
|
|
53 |
# 前置不可或缺的步骤 |
|
|
54 |
sig = resample(sig, config.target_point_num) |
|
|
55 |
# # 数据增强 |
|
|
56 |
if train: |
|
|
57 |
if np.random.randn() > 0.5: sig = scaling(sig) |
|
|
58 |
if np.random.randn() > 0.5: sig = verflip(sig) |
|
|
59 |
if np.random.randn() > 0.5: sig = shift(sig) |
|
|
60 |
# 后置不可或缺的步骤 |
|
|
61 |
sig = sig.transpose() |
|
|
62 |
sig = torch.tensor(sig.copy(), dtype=torch.float) |
|
|
63 |
return sig |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
class ECGDataset(Dataset): |
|
|
67 |
""" |
|
|
68 |
A generic data loader where the samples are arranged in this way: |
|
|
69 |
dd = {'train': train, 'val': val, "idx2name": idx2name, 'file2idx': file2idx} |
|
|
70 |
""" |
|
|
71 |
|
|
|
72 |
def __init__(self, data_path, train=True): |
|
|
73 |
super(ECGDataset, self).__init__() |
|
|
74 |
dd = torch.load(config.train_data) |
|
|
75 |
self.train = train |
|
|
76 |
self.data = dd['train'] if train else dd['val'] |
|
|
77 |
self.idx2name = dd['idx2name'] |
|
|
78 |
self.file2idx = dd['file2idx'] |
|
|
79 |
self.wc = 1. / np.log(dd['wc']) |
|
|
80 |
|
|
|
81 |
def __getitem__(self, index): |
|
|
82 |
fid = self.data[index] |
|
|
83 |
file_path = os.path.join(config.train_dir, fid) |
|
|
84 |
df = pd.read_csv(file_path, sep=' ').values |
|
|
85 |
x = transform(df, self.train) |
|
|
86 |
target = np.zeros(config.num_classes) |
|
|
87 |
target[self.file2idx[fid]] = 1 |
|
|
88 |
target = torch.tensor(target, dtype=torch.float32) |
|
|
89 |
return x, target |
|
|
90 |
|
|
|
91 |
def __len__(self): |
|
|
92 |
return len(self.data) |
|
|
93 |
|
|
|
94 |
|
|
|
95 |
if __name__ == '__main__': |
|
|
96 |
d = ECGDataset(config.train_data) |
|
|
97 |
print(d[0]) |