--- a +++ b/src/dataset.py @@ -0,0 +1,13 @@ +import torch +from torch.utils.data import Dataset + +class OmicsData(Dataset): + def __init__(self, X, y): + self.X = torch.tensor(X, dtype=torch.float32) + self.y = torch.tensor(y, dtype=torch.float32) + + def __len__(self): + return len(self.X) + + def __getitem__(self, idx): + return self.X[idx], self.y[idx] \ No newline at end of file