|
a |
|
b/data.py |
|
|
1 |
from re import X |
|
|
2 |
import scipy.io |
|
|
3 |
import numpy as np |
|
|
4 |
import torch |
|
|
5 |
|
|
|
6 |
from torch.utils.data import Dataset, DataLoader |
|
|
7 |
from torch.utils.data.dataloader import default_collate |
|
|
8 |
from sklearn.preprocessing import OrdinalEncoder |
|
|
9 |
|
|
|
10 |
class dataset(Dataset):
    """Torch Dataset over pre-loaded trial arrays.

    Args:
        X: indexable collection of per-trial arrays shaped (channels, time);
           time must be >= 4000 samples.
        y: labels aligned with X (len(y) defines the dataset length).
        train: kept for interface compatibility. The train and eval branches
            had become identical (both took the fixed window [0, 4000)), so
            the flag currently has no effect; it is retained so a random
            training crop can be reinstated without touching call sites.
    """

    def __init__(self, X, y, train=True):
        self.X = X
        self.y = y
        self.train = train

    def __len__(self):
        # Length is defined by the labels, matching the original contract.
        return len(self.y)

    def __getitem__(self, idx):
        # Fixed crop to the first 4000 time samples for both train and eval.
        # (The previous random-crop variants were dead, commented-out code
        # and have been removed; both live branches were identical.)
        x = self.X[idx][:, 0:4000]
        return x, self.y[idx]
|
|
36 |
|
|
|
37 |
def generate_x_train(mat): |
|
|
38 |
# out: num_data_points * chl * trial_length |
|
|
39 |
data = [] |
|
|
40 |
last_label = False |
|
|
41 |
for i in range(0, len(mat['mrk'][0][0][0][0])-1): |
|
|
42 |
start_idx = mat['mrk'][0][0][0][0][i] |
|
|
43 |
end_idx = mat['mrk'][0][0][0][0][i+1] |
|
|
44 |
# to resolve shape issues, we use a shifted window |
|
|
45 |
# (possible overlapping but acceptable given it's trivial) |
|
|
46 |
end_idx += (8000 + start_idx - end_idx) |
|
|
47 |
data.append(mat['cnt'][start_idx: end_idx,].T) |
|
|
48 |
# add the last datapoint |
|
|
49 |
if len(mat['cnt']) - mat['mrk'][0][0][0][0][-1] >= 8000: |
|
|
50 |
last_label = True |
|
|
51 |
start_idx = mat['mrk'][0][0][0][0][-1] |
|
|
52 |
end_idx = start_idx + 8000 |
|
|
53 |
data.append(mat['cnt'][start_idx: end_idx,].T) |
|
|
54 |
return np.array(data), last_label |
|
|
55 |
|
|
|
56 |
def generate_y_train(mat, last_label):
    """Map the raw -1/+1 markers to the class names from the metadata.

    Returns a 1-D array of class-name labels; the final label is dropped
    when `last_label` is False (no full trial was cut for the last marker).
    """
    classes = mat['nfo']['classes'][0][0][0]
    name_of = {-1: classes[0][0], 1: classes[1][0]}
    raw = mat['mrk'][0][0][1][0]
    labels = np.array([name_of.get(v) for v in raw])
    return labels if last_label else labels[:-1]
|
|
64 |
|
|
|
65 |
def generate_data(files):
    """Load every .mat recording, cut trials, and stack the results.

    Returns (X, y): X stacks the per-file trial arrays along axis 0;
    y is the concatenated labels, ordinal-encoded to an (N, 1) array.
    """
    X_parts, y_parts = [], []
    for path in files:
        print(path)  # progress indicator while loading
        mat = scipy.io.loadmat(path)
        trials, keep_last = generate_x_train(mat)
        X_parts.append(trials)
        y_parts.append(generate_y_train(mat, keep_last))
    X = np.concatenate(X_parts, axis=0)
    y = np.concatenate(y_parts)
    y = OrdinalEncoder().fit_transform(y.reshape(-1, 1))
    return X, y
|
|
76 |
|
|
|
77 |
def split_data(X, y, train_size=1000, val_size=197):
    """Deterministically shuffle and split (X, y) into train/val/test.

    The split sizes were hard-coded for one fixed dataset size; they are
    now keyword parameters whose defaults reproduce the original
    1000/197/remainder split, so existing callers are unaffected. The
    test split receives all remaining samples.

    Returns:
        (train_X, train_y, val_X, val_y, test_X, test_y)
    """
    # NOTE(review): this seeds the *global* NumPy RNG, exactly as the
    # original did — kept so the permutation (and thus the split) stays
    # bit-identical for existing users.
    np.random.seed(seed=42)
    indices = np.random.choice(len(y), len(y), replace=False)
    train_idx = indices[:train_size]
    val_idx = indices[train_size:train_size + val_size]
    test_idx = indices[train_size + val_size:]
    return (X[train_idx], y[train_idx],
            X[val_idx], y[val_idx],
            X[test_idx], y[test_idx])
|
|
89 |
|
|
|
90 |
def get_loaders(train_X, train_y, val_X, val_y, test_X, test_y):
    """Wrap the three splits in DataLoaders keyed 'train'/'val'/'test'.

    All loaders use batch_size=1, one worker, pinned memory, and keep the
    last (possibly partial) batch — the three previously duplicated,
    byte-identical configurations are now built in one place.
    """
    def _loader(ds):
        # Single source of truth for the shared loader settings.
        return DataLoader(
            ds,
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            drop_last=False,
        )

    return {
        'train': _loader(dataset(train_X, train_y, train=True)),
        'val': _loader(dataset(val_X, val_y, train=False)),
        'test': _loader(dataset(test_X, test_y, train=False)),
    }