a | b/dataset.py | ||
---|---|---|---|
1 | import torch |
||
2 | import torch.utils.data as data |
||
3 | |||
4 | class Dataset(data.Dataset): |
||
5 | |||
6 | def __init__(self, x, y): |
||
7 | |||
8 | self.x = x |
||
9 | self.y = y |
||
10 | |||
11 | def __len__(self): |
||
12 | |||
13 | return len(self.x) |
||
14 | |||
15 | def __getitem__(self, idx): |
||
16 | |||
17 | x_item = torch.tensor(self.x[idx]).double() |
||
18 | y_item = torch.tensor(self.y[idx]).long() |
||
19 | |||
20 | return x_item, y_item |