[39d39d]: / py_version / dataset.py

Download this file

20 lines (12 with data), 318 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.utils.data as data
class Dataset(data.Dataset):
    """Map-style dataset pairing samples ``x`` with labels ``y``.

    Each item is returned as ``(x_item, y_item)`` where ``x_item`` is a
    ``torch.float64`` tensor and ``y_item`` is a ``torch.int64`` tensor.

    Args:
        x: Indexable collection of samples. Each element is presumably
            something ``torch.tensor`` accepts (nested lists, NumPy arrays,
            scalars) or an existing tensor -- confirm against callers.
        y: Indexable collection of labels, indexed in lockstep with ``x``.

    Note:
        The dataset length is taken from ``x`` alone; ``x`` and ``y`` are
        assumed to have the same length, but this is not checked.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``.

        Features are converted to float64 and labels to int64. Both are new
        tensors (copies), so in-place edits by the caller do not modify the
        stored ``x`` / ``y`` data.
        """
        x_item = self._to_tensor(self.x[idx], torch.float64)
        y_item = self._to_tensor(self.y[idx], torch.long)
        return x_item, y_item

    @staticmethod
    def _to_tensor(value, dtype):
        """Copy ``value`` into a new ``dtype`` tensor with no autograd history."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(t) on an existing tensor emits a UserWarning; PyTorch
            # documents t.detach().clone() as the equivalent copy.
            return value.detach().clone().to(dtype)
        # Request the target dtype at construction: building with the default
        # dtype (float32 for Python floats) and casting afterwards would round
        # values twice (e.g. 0.1 -> 0.10000000149011612).
        return torch.tensor(value, dtype=dtype)