[d6904d]: / app / datasets / dl.py

Download this file

38 lines (30 with data), 860 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pickle
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from omegaconf import OmegaConf
from sklearn.model_selection import KFold, StratifiedKFold
from torch import nn
from torch.autograd import Variable
from torch.utils import data
from torch.utils.data import (
ConcatDataset,
DataLoader,
Dataset,
Subset,
SubsetRandomSampler,
TensorDataset,
random_split,
)
# NOTE: this definition rebinds the name ``Dataset`` that the module also
# imports from ``torch.utils.data``; after this point ``Dataset`` means this class.
class Dataset(data.Dataset):
    """Map-style dataset serving aligned ``(x, y, x_lab_length)`` triples.

    The three containers are stored exactly as given and are all indexed
    with the same key on lookup.
    """

    def __init__(self, x, y, x_lab_length):
        """Keep references to the three parallel containers.

        Args:
            x: Indexable container of inputs.
            y: Indexable, sized container of targets; it alone defines the
                dataset length.
            x_lab_length: Indexable container indexed in step with ``x``
                (presumably per-sample lengths -- confirm against the caller).
        """
        # No copying and no length check is performed here.
        self.x = x
        self.y = y
        self.x_lab_length = x_lab_length

    def __getitem__(self, index):
        """Return the tuple ``(x[index], y[index], x_lab_length[index])``."""
        # Original note said the returned items are tensors; that holds only
        # if the containers passed to __init__ hold tensors, since nothing
        # in this class converts them.
        return self.x[index], self.y[index], self.x_lab_length[index]

    def __len__(self):
        """Return the number of samples, measured on ``y`` only."""
        return len(self.y)
def get_dataset(x, y, x_lab_length):
    """Build a ``Dataset`` from three parallel containers.

    Thin factory: the arguments are forwarded unchanged, in the same order,
    to the ``Dataset`` constructor.

    Args:
        x: Container of inputs.
        y: Container of targets.
        x_lab_length: Container indexed in step with ``x``.

    Returns:
        The newly constructed ``Dataset`` instance.
    """
    return Dataset(x, y, x_lab_length)