[f539ea]: / src / dataset.py

Download this file

13 lines (10 with data), 342 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
import torch
from torch.utils.data import Dataset
class OmicsData(Dataset):
    """In-memory map-style dataset pairing a feature matrix with per-sample targets.

    Both inputs are copied into new tensors at construction time. Item ``i``
    is the pair ``(X[i], y[i])``, indexed along the first axis of each.
    """

    def __init__(self, X, y, *, x_dtype=torch.float32, y_dtype=torch.float32):
        """Build the dataset from array-like features and targets.

        Args:
            X: Features; anything ``torch.tensor`` accepts (nested lists,
                NumPy arrays) or an existing tensor. The first axis indexes
                samples.
            y: Targets, with the same number of samples as ``X`` along the
                first axis.
            x_dtype: dtype for the stored features (default ``float32``).
            y_dtype: dtype for the stored targets (default ``float32``); pass
                e.g. ``torch.long`` for class-index labels.

        Raises:
            ValueError: If either input is a scalar (has no sample axis) or
                the sample counts of ``X`` and ``y`` differ.
        """
        self.X = self._to_tensor(X, x_dtype)
        self.y = self._to_tensor(y, y_dtype)
        if self.X.dim() == 0 or self.y.dim() == 0:
            raise ValueError(
                "X and y must have at least one dimension (the sample axis)"
            )
        # Validate eagerly: a mismatch would otherwise surface later as an
        # IndexError inside a DataLoader, or silently drop trailing labels.
        if self.X.shape[0] != self.y.shape[0]:
            raise ValueError(
                "X and y must have the same number of samples, "
                f"got {self.X.shape[0]} and {self.y.shape[0]}"
            )

    @staticmethod
    def _to_tensor(data, dtype):
        """Return a new tensor holding a copy of *data* converted to *dtype*."""
        if isinstance(data, torch.Tensor):
            # torch.tensor(tensor) warns; this is the recommended equivalent
            # (a detached copy, same device) without the warning.
            return data.detach().to(dtype=dtype, copy=True)
        return torch.tensor(data, dtype=dtype)

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair for sample(s) *idx*."""
        return self.X[idx], self.y[idx]