Diff of /dataset.py [000000] .. [dcdaea]

Switch to unified view

a b/dataset.py
1
# -*- coding: utf-8 -*-
2
'''
3
@time: 2019/9/8 19:47
4
5
@ author: javis
6
'''
7
import pywt, os, copy
8
import torch
9
import numpy as np
10
import pandas as pd
11
from config import config
12
from torch.utils.data import Dataset
13
from sklearn.preprocessing import scale
14
from scipy import signal
15
16
17
def resample(sig, target_point_num=None):
    '''
    Resample the raw signal to a fixed number of points.

    :param sig: raw signal; passed unchanged to scipy.signal.resample,
        which resamples along its default axis
    :param target_point_num: target number of points; when falsy
        (None or 0) no resampling is done
    :return: the resampled signal, or sig itself when no target is given
    '''
    if not target_point_num:
        return sig
    return signal.resample(sig, target_point_num)
26
27
def scaling(X, sigma=0.1):
    '''
    Multiply every column of X by its own random factor.

    One factor per column is drawn from a normal distribution with mean 1.0
    and standard deviation sigma; the same factor is applied to all rows of
    that column.

    :param X: array whose shape[0] / shape[1] give the row / column counts
    :param sigma: standard deviation of the random factors
    :return: X multiplied element-wise by the (rows, cols) factor matrix
    '''
    row_count, col_count = X.shape[0], X.shape[1]
    col_factors = np.random.normal(loc=1.0, scale=sigma, size=(1, col_count))
    # Repeat the single row of factors for every row of X.
    factor_matrix = np.tile(col_factors, (row_count, 1))
    return X * factor_matrix
31
32
def verflip(sig):
    '''
    Flip the signal array upside down, i.e. reverse the order of its rows.

    :param sig: array indexable with two subscripts
    :return: sig with axis 0 reversed (a reversed slice, not a copy)
    '''
    reversed_rows = slice(None, None, -1)
    return sig[reversed_rows, :]
39
40
def shift(sig, interval=20):
    '''
    Shift every column up or down by a random integer offset.

    For each column an offset is drawn with np.random.choice from
    range(-interval, interval) (so +interval itself is never drawn) and
    added to all samples of that column.

    The shift is applied to a copy, so the caller's array is not modified.

    :param sig: 2-D numpy array; one offset is drawn per column
    :param interval: bound of the offset range; must be positive, an empty
        range makes np.random.choice raise ValueError
    :return: a new array with the same shape as sig holding the shifted signal
    '''
    # Work on a copy: the original implementation added the offsets in place
    # and silently altered the array passed in by the caller.
    shifted = sig.copy()
    for col in range(shifted.shape[1]):
        offset = np.random.choice(range(-interval, interval))
        shifted[:, col] += offset
    return shifted
50
51
52
def transform(sig, train=False):
    '''
    Turn a raw signal array into a float tensor, optionally augmenting it.

    Steps: resample to config.target_point_num points, apply the random
    augmentations when train is true, transpose, and convert to a
    torch.float tensor.

    :param sig: raw signal array as read from file
    :param train: when true, scaling / verflip / shift are each applied
        independently with probability 0.5
    :return: torch.float tensor holding the transposed, processed signal
    '''
    # Mandatory pre-processing step.
    sig = resample(sig, config.target_point_num)
    # Data augmentation (training only).
    # np.random.rand() is uniform on [0, 1), so "> 0.5" fires half of the
    # time. The previous np.random.randn() is a standard-normal draw and
    # exceeds 0.5 only about 31% of the time.
    if train:
        for augment in (scaling, verflip, shift):
            if np.random.rand() > 0.5:
                sig = augment(sig)
    # Mandatory post-processing step.
    sig = sig.transpose()
    # NOTE(review): .copy() presumably exists because verflip returns a
    # reversed (negatively strided) slice that torch cannot convert directly.
    return torch.tensor(sig.copy(), dtype=torch.float)
64
65
66
class ECGDataset(Dataset):
    """
    Dataset of signal files listed in a saved index dictionary.

    The dictionary loaded from disk is arranged in this way:
    dd = {'train': train, 'val': val, "idx2name": idx2name, 'file2idx': file2idx}
    A 'wc' entry is read from it as well.
    """

    def __init__(self, data_path, train=True):
        """
        Load the index dictionary and select the train or validation split.

        :param data_path: path of the dictionary saved with torch.save; when
            empty or None, config.train_data is used instead
        :param train: True selects dd['train'] and enables augmentation in
            __getitem__; False selects dd['val']
        """
        super(ECGDataset, self).__init__()
        # Honour the data_path argument (it used to be ignored in favour of
        # config.train_data).
        # NOTE(security): torch.load unpickles the file; only load trusted data.
        dd = torch.load(data_path or config.train_data)
        self.train = train
        self.data = dd['train'] if train else dd['val']
        self.idx2name = dd['idx2name']
        self.file2idx = dd['file2idx']
        # Inverse log of dd['wc'].
        # NOTE(review): presumably per-class counts used as loss weights - confirm.
        self.wc = 1. / np.log(dd['wc'])

    def __getitem__(self, index):
        """
        Read one signal file and build its multi-hot target vector.

        :param index: position inside the selected split
        :return: (x, target) where x is the tensor produced by transform and
            target is a float32 tensor of length config.num_classes
        """
        fid = self.data[index]
        file_path = os.path.join(config.train_dir, fid)
        samples = pd.read_csv(file_path, sep=' ').values
        x = transform(samples, self.train)
        # Set to 1 every class index listed for this file.
        target = np.zeros(config.num_classes)
        target[self.file2idx[fid]] = 1
        target = torch.tensor(target, dtype=torch.float32)
        return x, target

    def __len__(self):
        """Return the number of files in the selected split."""
        return len(self.data)
93
94
95
if __name__ == '__main__':
    # Smoke test: build the training split and print its first sample.
    dataset = ECGDataset(config.train_data)
    first_sample = dataset[0]
    print(first_sample)