data.py

import scipy.io
import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import OrdinalEncoder

class dataset(Dataset):
    """Wraps pre-cut trials X (num_trials x channels x samples) and labels y."""
    def __init__(self, X, y, train=True):
        self.X = X
        self.y = y
        self.train = train

    def __len__(self):
        return len(self.y)

    # Earlier variant: random 600-sample crop at train time, fixed window at eval time.
    # def __getitem__(self, idx):
    #     rng = np.random.randint(0, high=200)
    #     if self.train:
    #         x = self.X[idx][:, rng:rng + 600]
    #     else:
    #         x = self.X[idx][:, 200:800]
    #     return x, self.y[idx]

    def __getitem__(self, idx):
        x = self.X[idx]
        if self.train:
            # Random-crop augmentation is currently disabled; training uses the
            # same fixed 4000-sample window as evaluation.
            # rn = np.random.randint(0, high=500)
            # x = x[:, rn:rn + 4000]
            x = x[:, 0:4000]
        else:
            x = x[:, 0:4000]
        return x, self.y[idx]

def generate_x_train(mat):
    # out: num_data_points * channels * trial_length
    data = []
    last_label = False
    markers = mat['mrk'][0][0][0][0]
    for i in range(0, len(markers) - 1):
        start_idx = markers[i]
        # Cut a fixed 8000-sample window from each marker so every trial has the
        # same shape; consecutive windows may overlap slightly, which is acceptable.
        end_idx = start_idx + 8000
        data.append(mat['cnt'][start_idx:end_idx].T)
    # Add the last trial only if enough samples remain after the final marker.
    if len(mat['cnt']) - markers[-1] >= 8000:
        last_label = True
        start_idx = markers[-1]
        end_idx = start_idx + 8000
        data.append(mat['cnt'][start_idx:end_idx].T)
    return np.array(data), last_label

def generate_y_train(mat, last_label):
    # out: 1 * num_labels
    # Map the marker codes (-1 / 1) to the two class names stored in mat['nfo'].
    class1, class2 = mat['nfo']['classes'][0][0][0][0][0], mat['nfo']['classes'][0][0][0][1][0]
    mapping = {-1: class1, 1: class2}
    labels = np.vectorize(mapping.get)(mat['mrk'][0][0][1])[0]
    # Drop the final label if the corresponding trial was too short to extract.
    if not last_label:
        labels = labels[:-1]
    return labels

def generate_data(files):
    # Load each .mat file, cut it into fixed-length trials, and ordinal-encode
    # the class labels.
    X, y = [], []
    for file in files:
        print(file)
        mat = scipy.io.loadmat(file)
        X_batch, last_label = generate_x_train(mat)
        X.append(X_batch)
        y.append(generate_y_train(mat, last_label))
    X, y = np.concatenate(X, axis=0), np.concatenate(y)
    y = OrdinalEncoder().fit_transform(y.reshape(-1, 1))
    return X, y

def split_data(X, y):
    # Shuffle all trial indices once with a fixed seed, then split into
    # train / val / test (the test split takes whatever remains).
    def get_idx():
        np.random.seed(seed=42)
        return np.random.choice(len(y), len(y), replace=False)

    train_size, val_size, test_size = 1000, 197, 200
    indices = get_idx()
    train_idx = indices[0:train_size]
    val_idx = indices[train_size:train_size + val_size]
    test_idx = indices[train_size + val_size:]
    train_X, train_y, val_X, val_y, test_X, test_y = \
        X[train_idx], y[train_idx], X[val_idx], y[val_idx], X[test_idx], y[test_idx]
    return train_X, train_y, val_X, val_y, test_X, test_y

def get_loaders(train_X, train_y, val_X, val_y, test_X, test_y):
    train_set = dataset(train_X, train_y, True)
    val_set = dataset(val_X, val_y, False)
    test_set = dataset(test_X, test_y, False)
    data_loader_train = DataLoader(
        train_set,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = DataLoader(
        val_set,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_test = DataLoader(
        test_set,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        drop_last=False,
    )
    dataloaders = {
        'train': data_loader_train,
        'val': data_loader_val,
        'test': data_loader_test,
    }
    return dataloaders
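
# --- Usage sketch (assumption, not part of the original module) ---
# A minimal example of how these pieces presumably fit together:
# generate_data -> split_data -> get_loaders. The .mat path below is a
# hypothetical placeholder, and the hard-coded split sizes in split_data
# (1000 / 197 / 200) assume roughly 1397 trials in total.
if __name__ == "__main__":
    files = ["subject_a.mat"]  # hypothetical file path
    X, y = generate_data(files)
    train_X, train_y, val_X, val_y, test_X, test_y = split_data(X, y)
    dataloaders = get_loaders(train_X, train_y, val_X, val_y, test_X, test_y)
    for x_batch, y_batch in dataloaders['train']:
        # Each batch is a (1, channels, 4000) window and a (1, 1) ordinal label.
        print(x_batch.shape, y_batch.shape)
        break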