Diff of /dataloader.py [000000] .. [0d4320]

--- a/dataloader.py
+++ b/dataloader.py
from torch.utils.data import Dataset
import pickle
import numpy as np
import torch


def load_pickle(fname):
    with open(fname, 'rb') as f:
        return pickle.load(f)


def downsample(train_idx, neg_young, train_idx_pos):
    # Rebalance the training split: drop 450,000 randomly chosen entries of
    # neg_young from the training indices, then append 50 copies of every
    # positive index.
    downsamples = np.random.permutation(neg_young)[:450000]
    mask = np.ones(len(train_idx), dtype=bool)
    mask[downsamples] = False
    downsample_idx = np.concatenate((train_idx[mask], np.repeat(train_idx_pos, 50)))
    return downsample_idx


class OriginalData:
    # Loads the preprocessed feature matrix, binary labels, and the
    # precomputed feature selection from pickles under `path`.
    def __init__(self, path):
        self.path = path
        self.feature_selection = load_pickle(path + 'frts_selection.pkl')
        self.x = load_pickle(path + 'preprocess_x.pkl')[:, self.feature_selection]
        self.y = load_pickle(path + 'y_bin.pkl')

    def datasampler(self, idx_path, train=True):
        idx = load_pickle(self.path + idx_path)
        if train:
            # Positives are the indices whose label equals 1.
            downsample_idx = downsample(idx, load_pickle(self.path + 'neg_young.pkl'), idx[self.y[idx] == 1])
            return self.x[downsample_idx, :], self.y[downsample_idx]
        return self.x, self.y


class EHRData(Dataset):
    # Thin Dataset wrapper around a (sparse) feature matrix and its labels.
    def __init__(self, data, cla):
        self.data = data
        self.cla = cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        return self.data[idx], self.cla[idx]


def collate_fn(data):
    # Densify each sparse feature row, append its label, and stack the
    # batch into a single LongTensor.
    data_list = []
    for datum in data:
        data_list.append(np.hstack((datum[0].toarray().ravel(), datum[1])))
    return torch.from_numpy(np.array(data_list)).long()
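
For context, a minimal usage sketch of how these pieces fit together with torch.utils.data.DataLoader. It is not part of the diff: the directory 'data/' and the index file name 'train_idx.pkl' are illustrative assumptions, and the batch size is arbitrary.

from torch.utils.data import DataLoader
from dataloader import OriginalData, EHRData, collate_fn

original = OriginalData('data/')                          # path must end with '/' since file names are appended to it
x_train, y_train = original.datasampler('train_idx.pkl')  # rebalanced split; expects neg_young.pkl under the same path
train_set = EHRData(x_train, y_train)
loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn)

for batch in loader:
    # collate_fn hstacks features and label, so the label is the last column.
    features, labels = batch[:, :-1], batch[:, -1]
    ...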