Diff of /data.py [000000] .. [597177]

--- a
+++ b/data.py
@@ -0,0 +1,118 @@
+import scipy.io
+import numpy as np
+import torch
+
+from torch.utils.data import Dataset, DataLoader
+from sklearn.preprocessing import OrdinalEncoder
+
+class dataset(Dataset):
+    # X: num_data_points x chl x trial_length, y: per-trial labels
+    def __init__(self, X, y, train=True):
+        self.X = X
+        self.y = y
+        self.train = train
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        # Both train and eval currently take the first 4000 samples of each
+        # 8000-sample trial, so self.train has no effect on the crop.
+        # A random-crop augmentation (slicing at a random offset, e.g.
+        # x[:, rn:rn + 4000] with rn drawn per item) was tried earlier and
+        # is intentionally left disabled.
+        x = self.X[idx]
+        x = x[:, 0:4000]
+        return x, self.y[idx]
+
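+# Hypothetical usage sketch (not part of the original pipeline): the array
+# shapes below are assumptions for illustration; the real channel count
+# comes from mat['cnt'].
+def _demo_dataset():
+    X = np.random.randn(4, 59, 8000)  # 4 trials x 59 channels x 8000 samples (assumed)
+    y = np.zeros(4)
+    ds = dataset(X, y, train=True)
+    x, label = ds[0]
+    assert x.shape == (59, 4000)  # __getitem__ crops each trial to 4000 samples
+    return x, label
+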
+def generate_x_train(mat):
+    # out: num_data_points * chl * trial_length
+    # Cuts the continuous recording mat['cnt'] (samples x channels) into one
+    # fixed-length 8000-sample window per marker, transposed to
+    # (channels, samples).
+    data = []
+    last_label = False
+    markers = mat['mrk'][0][0][0][0]
+    for i in range(len(markers) - 1):
+        # Fixed window: neighbouring windows may overlap when markers are
+        # closer than 8000 samples apart, which is acceptable here.
+        start_idx = markers[i]
+        data.append(mat['cnt'][start_idx:start_idx + 8000].T)
+    # Keep the last trial only if a full window fits before the recording ends.
+    if len(mat['cnt']) - markers[-1] >= 8000:
+        last_label = True
+        data.append(mat['cnt'][markers[-1]:markers[-1] + 8000].T)
+    return np.array(data), last_label
+
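+# Hypothetical fixture mimicking the nesting of scipy.io.loadmat output;
+# all values (marker positions, channel count, class names) are invented
+# for illustration only.
+def _fake_mat():
+    pos = np.array([0, 5000, 10000])  # marker positions in samples
+    codes = np.array([[-1, 1, -1]])   # per-trial class codes
+    rec = [[pos], codes]              # rec[0][0] -> positions, rec[1] -> codes
+    c = [['left'], ['right']]         # mimics mat['nfo']['classes'][0][0][0]
+    return {
+        'cnt': np.random.randn(20000, 59),  # continuous EEG: samples x channels
+        'mrk': [[rec]],
+        'nfo': {'classes': [[[c]]]},
+    }
+
+def _demo_generate_x_train():
+    X, last_label = generate_x_train(_fake_mat())
+    assert X.shape == (3, 59, 8000)  # one 8000-sample window per marker
+    assert last_label                # 20000 - 10000 >= 8000, so the last trial fits
+    return X
+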
+def generate_y_train(mat, last_label):
+    # out: 1 * num_labels
+    # Map the -1/+1 marker codes to the two class names stored in
+    # mat['nfo']['classes'].
+    class1 = mat['nfo']['classes'][0][0][0][0][0]
+    class2 = mat['nfo']['classes'][0][0][0][1][0]
+    mapping = {-1: class1, 1: class2}
+    labels = np.vectorize(mapping.get)(mat['mrk'][0][0][1])[0]
+    # Drop the last label when generate_x_train discarded the last trial.
+    if not last_label:
+        labels = labels[:-1]
+    return labels
+
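+# Continues the hypothetical fixture above: -1 maps to the first class name,
+# +1 to the second, and nothing is trimmed because the fake recording has
+# room for the last trial.
+def _demo_generate_y_train():
+    labels = generate_y_train(_fake_mat(), last_label=True)
+    assert list(labels) == ['left', 'right', 'left']
+    return labels
+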
+def generate_data(files):
+    X, y = [], []
+    for file in files:
+        print(file)  # progress: which .mat file is being loaded
+        mat = scipy.io.loadmat(file)
+        X_batch, last_label = generate_x_train(mat)
+        X.append(X_batch)
+        y.append(generate_y_train(mat, last_label))
+    X, y = np.concatenate(X, axis=0), np.concatenate(y)
+    # Encode the string class names as 0/1 floats, shaped (num_trials, 1).
+    y = OrdinalEncoder().fit_transform(y.reshape(-1, 1))
+    return X, y
+
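+# Sketch of the encoding step above (class names hypothetical): string
+# labels become float codes in lexicographic order, shaped (n, 1).
+def _demo_ordinal_encoding():
+    y = np.array(['left', 'right', 'left'])
+    encoded = OrdinalEncoder().fit_transform(y.reshape(-1, 1))
+    assert encoded.tolist() == [[0.0], [1.0], [0.0]]  # 'left' -> 0, 'right' -> 1
+    return encoded
+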
+def split_data(X, y):
+    # Deterministic shuffle: the fixed seed makes the split reproducible.
+    def get_idx():
+        np.random.seed(seed=42)
+        return np.random.choice(len(y), len(y), replace=False)
+    # Hard-coded split sizes; they assume 1397 trials in total
+    # (the test set takes the remaining 200).
+    train_size, val_size = 1000, 197
+    indices = get_idx()
+    train_idx = indices[:train_size]
+    val_idx = indices[train_size:train_size + val_size]
+    test_idx = indices[train_size + val_size:]
+    return X[train_idx], y[train_idx], X[val_idx], y[val_idx], X[test_idx], y[test_idx]
+
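+# Hypothetical check that the fixed seed makes the split reproducible;
+# the 1397 here matches the hard-coded sizes in split_data.
+def _demo_split_determinism():
+    X = np.arange(1397 * 2).reshape(1397, 2)
+    y = np.arange(1397)
+    first, second = split_data(X, y), split_data(X, y)
+    assert all(np.array_equal(a, b) for a, b in zip(first, second))
+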
+def get_loaders(train_X, train_y, val_X, val_y, test_X, test_y):
+    train_set = dataset(train_X, train_y, train=True)
+    val_set = dataset(val_X, val_y, train=False)
+    test_set = dataset(test_X, test_y, train=False)
+    # No shuffle here: split_data already shuffled the trials with a fixed seed.
+    def make_loader(ds):
+        return DataLoader(
+            ds,
+            batch_size=1,
+            num_workers=1,
+            pin_memory=True,
+            drop_last=False,
+        )
+    return {
+        'train': make_loader(train_set),
+        'val': make_loader(val_set),
+        'test': make_loader(test_set),
+    }
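+
+# End-to-end sketch using the hypothetical fixture above; the real pipeline
+# would call generate_data() on a list of .mat file paths instead.
+if __name__ == '__main__':
+    mat = _fake_mat()
+    X, last_label = generate_x_train(mat)
+    y = generate_y_train(mat, last_label)
+    y = OrdinalEncoder().fit_transform(y.reshape(-1, 1))
+    loaders = get_loaders(X[:1], y[:1], X[1:2], y[1:2], X[2:], y[2:])
+    for x, label in loaders['train']:
+        print(x.shape, label)  # expect torch.Size([1, 59, 4000]) and a 0/1 label
+        break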