--- a +++ b/app/datasets/dl.py @@ -0,0 +1,37 @@ +import pickle + +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn_utils +from omegaconf import OmegaConf +from sklearn.model_selection import KFold, StratifiedKFold +from torch import nn +from torch.autograd import Variable +from torch.utils import data +from torch.utils.data import ( + ConcatDataset, + DataLoader, + Dataset, + Subset, + SubsetRandomSampler, + TensorDataset, + random_split, +) + + +class Dataset(data.Dataset): + def __init__(self, x, y, x_lab_length): + self.x = x + self.y = y + self.x_lab_length = x_lab_length + + def __getitem__(self, index): # 返回的是tensor + return self.x[index], self.y[index], self.x_lab_length[index] + + def __len__(self): + return len(self.y) + + +def get_dataset(x, y, x_lab_length): + return Dataset(x, y, x_lab_length)