# datasets.py

from torch.utils.data import Dataset, DataLoader
import torch
import random


class ContrastiveDataset(Dataset):

    def __init__(self, train_seq, train_mask, train_y, positive_prob=0.5):
        super().__init__()
        self.train_seq = train_seq
        self.train_mask = train_mask
        self.train_y = train_y
        self.positive_prob = positive_prob  # probability of sampling two texts from the same category

        self.hash_table = {}  # format: {category: [i1, i2, ...]}

        # Build a lookup table: each key is a category label and the value is
        # the list of indices of the texts that belong to that category.
        for i in range(len(self.train_seq)):
            label = self.train_y[i].item()
            if label in self.hash_table:
                self.hash_table[label].append(i)
            else:
                self.hash_table[label] = [i]

    def __getitem__(self, index):
        """
        Sample a pair of texts, drawn from the same category with probability self.positive_prob.
        :param index: index (int)
        :return: seq_0 - a sequence of IDs (each ID represents a word in the vocabulary)
                 seq_1 - a sequence different from seq_0 (a different text)
                 mask_0 - attention mask for seq_0
                 mask_1 - attention mask for seq_1
                 same_class - 1 if seq_0 and seq_1 are from the same category, 0 otherwise
        """
        # Draw a positive pair (same category) with probability self.positive_prob.
        same_class = random.random() < self.positive_prob

        seq_0 = self.train_seq[index]
        mask_0 = self.train_mask[index]
        label_0 = self.train_y[index].item()
        class_samples = self.hash_table[label_0]

        # If this category contains only a single text, we cannot draw a second
        # text from it, so fall back to a negative pair.
        if len(class_samples) < 2:
            same_class = False

        if same_class:
            # Resample within the same category until we hit a different text.
            while True:
                index_1 = random.choice(class_samples)
                if index_1 != index:
                    seq_1 = self.train_seq[index_1]
                    mask_1 = self.train_mask[index_1]
                    break
        else:
            # Resample from the whole dataset until we hit a text from another category.
            while True:
                index_1 = random.randint(0, len(self) - 1)
                if index_1 != index and self.train_y[index_1].item() != label_0:
                    seq_1 = self.train_seq[index_1]
                    mask_1 = self.train_mask[index_1]
                    break

        return seq_0, seq_1, mask_0, mask_1, torch.tensor(same_class, dtype=torch.float)

    def __len__(self):
        return len(self.train_seq)

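
# A minimal usage sketch, assuming train_seq and train_mask are pre-tokenized
# (num_texts, seq_len) tensors and train_y is a (num_texts,) label tensor (these
# shapes are assumptions, not stated in this file): the dataset plugs straight
# into the DataLoader imported above, and each batch yields two sequences, their
# attention masks, and a 0/1 same-category flag, suitable for a pairwise
# contrastive loss.
#
#   dataset = ContrastiveDataset(train_seq, train_mask, train_y, positive_prob=0.5)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   seq_0, seq_1, mask_0, mask_1, same_class = next(iter(loader))
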
class SimpleDataset(Dataset):

    def __init__(self, seq, mask, y):
        super().__init__()
        self.seq = seq
        self.mask = mask
        self.y = y

    def __getitem__(self, index):
        """
        Return a single text, served in dataset order.
        :param index: index (int)
        :return: seq - a sequence of IDs (each ID represents a word in the vocabulary)
                 mask - attention mask for seq
                 y - the category of this text
        """
        return self.seq[index], self.mask[index], torch.tensor(self.y[index].item())

    def __len__(self):
        return len(self.seq)
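

# A minimal smoke test sketch: builds tiny synthetic inputs and draws one batch
# from each dataset. The shapes, vocabulary size, and labels below are arbitrary
# placeholders, not values from the original project.
if __name__ == "__main__":
    num_texts, seq_len, vocab_size = 8, 10, 100
    train_seq = torch.randint(0, vocab_size, (num_texts, seq_len))
    train_mask = torch.ones(num_texts, seq_len, dtype=torch.long)
    train_y = torch.tensor([0, 0, 1, 1, 2, 2, 0, 1])  # several categories, so negative pairs exist

    contrastive = ContrastiveDataset(train_seq, train_mask, train_y, positive_prob=0.5)
    seq_0, seq_1, mask_0, mask_1, same_class = next(iter(DataLoader(contrastive, batch_size=4, shuffle=True)))
    print(seq_0.shape, seq_1.shape, same_class)  # torch.Size([4, 10]) torch.Size([4, 10]) and four 0/1 flags

    simple = SimpleDataset(train_seq, train_mask, train_y)
    seq, mask, y = next(iter(DataLoader(simple, batch_size=4)))
    print(seq.shape, mask.shape, y.shape)  # torch.Size([4, 10]) torch.Size([4, 10]) torch.Size([4])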