Diff of /app/datasets/dl.py [000000] .. [d6904d]

Switch to unified view

a b/app/datasets/dl.py
1
import pickle
2
3
import numpy as np
4
import torch
5
import torch.nn.functional as F
6
import torch.nn.utils.rnn as rnn_utils
7
from omegaconf import OmegaConf
8
from sklearn.model_selection import KFold, StratifiedKFold
9
from torch import nn
10
from torch.autograd import Variable
11
from torch.utils import data
12
from torch.utils.data import (
13
    ConcatDataset,
14
    DataLoader,
15
    Dataset,
16
    Subset,
17
    SubsetRandomSampler,
18
    TensorDataset,
19
    random_split,
20
)
21
22
23
class Dataset(data.Dataset):
24
    def __init__(self, x, y, x_lab_length):
25
        self.x = x
26
        self.y = y
27
        self.x_lab_length = x_lab_length
28
29
    def __getitem__(self, index):  # 返回的是tensor
30
        return self.x[index], self.y[index], self.x_lab_length[index]
31
32
    def __len__(self):
33
        return len(self.y)
34
35
36
def get_dataset(x, y, x_lab_length):
37
    return Dataset(x, y, x_lab_length)