|
a |
|
b/app/datasets/dl.py |
|
|
1 |
import pickle |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import torch |
|
|
5 |
import torch.nn.functional as F |
|
|
6 |
import torch.nn.utils.rnn as rnn_utils |
|
|
7 |
from omegaconf import OmegaConf |
|
|
8 |
from sklearn.model_selection import KFold, StratifiedKFold |
|
|
9 |
from torch import nn |
|
|
10 |
from torch.autograd import Variable |
|
|
11 |
from torch.utils import data |
|
|
12 |
from torch.utils.data import ( |
|
|
13 |
ConcatDataset, |
|
|
14 |
DataLoader, |
|
|
15 |
Dataset, |
|
|
16 |
Subset, |
|
|
17 |
SubsetRandomSampler, |
|
|
18 |
TensorDataset, |
|
|
19 |
random_split, |
|
|
20 |
) |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
class Dataset(data.Dataset): |
|
|
24 |
def __init__(self, x, y, x_lab_length): |
|
|
25 |
self.x = x |
|
|
26 |
self.y = y |
|
|
27 |
self.x_lab_length = x_lab_length |
|
|
28 |
|
|
|
29 |
def __getitem__(self, index): # 返回的是tensor |
|
|
30 |
return self.x[index], self.y[index], self.x_lab_length[index] |
|
|
31 |
|
|
|
32 |
def __len__(self): |
|
|
33 |
return len(self.y) |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def get_dataset(x, y, x_lab_length): |
|
|
37 |
return Dataset(x, y, x_lab_length) |