a | b/src/dataset.py | ||
---|---|---|---|
1 | import torch |
||
2 | from torch.utils.data import Dataset |
||
3 | |||
4 | class OmicsData(Dataset): |
||
5 | def __init__(self, X, y): |
||
6 | self.X = torch.tensor(X, dtype=torch.float32) |
||
7 | self.y = torch.tensor(y, dtype=torch.float32) |
||
8 | |||
9 | def __len__(self): |
||
10 | return len(self.X) |
||
11 | |||
12 | def __getitem__(self, idx): |
||
13 | return self.X[idx], self.y[idx] |