Diff of /dataset.py [000000] .. [39d39d]

Switch to unified view

a b/dataset.py
1
import torch
2
import torch.utils.data as data
3
4
class Dataset(data.Dataset):
5
6
    def __init__(self, x, y):
7
8
        self.x = x
9
        self.y = y
10
11
    def __len__(self):
12
13
        return len(self.x)
14
15
    def __getitem__(self, idx):
16
17
        x_item = torch.tensor(self.x[idx]).double()
18
        y_item = torch.tensor(self.y[idx]).long()
19
20
        return x_item, y_item