b/datasets.py

from torch.utils.data import Dataset, DataLoader
import torch
import random


class ContrastiveDataset(Dataset):

    def __init__(self, train_seq, train_mask, train_y, positive_prob=0.5):
        super().__init__()
        self.train_seq = train_seq
        self.train_mask = train_mask
        self.train_y = train_y
        self.positive_prob = positive_prob  # probability of sampling two texts from the same category

        self.hash_table = {}  # format: {category: [i1, i2, ...]}

        # Build a lookup table: each key is a category label and the value is
        # the list of indices of the texts that belong to that category.
        for i in range(len(self.train_seq)):
            label = self.train_y[i].item()
            if label in self.hash_table:
                self.hash_table[label].append(i)
            else:
                self.hash_table[label] = [i]

    def __getitem__(self, index):
        """
        Sample two texts, drawn from the same category with probability self.positive_prob.
        :param index: index (int)
        :return: seq_0 - a sequence of IDs (each ID represents a word in the vocabulary)
                 seq_1 - a sequence different from seq_0 (a different text)
                 mask_0 - attention mask for seq_0
                 mask_1 - attention mask for seq_1
                 same_class - 1 if seq_0 and seq_1 are from the same category, 0 otherwise
        """
        # True with probability self.positive_prob, i.e. draw a positive (same-category) pair.
        same_class = random.uniform(0, 1) < self.positive_prob

        seq_0 = self.train_seq[index]
        mask_0 = self.train_mask[index]
        label_0 = self.train_y[index].item()
        class_samples = self.hash_table[label_0]

        # If this category contains only a single text, we cannot draw a
        # distinct positive partner, so fall back to a negative pair.
        if len(class_samples) < 2:
            same_class = False

        if same_class:
            # Positive pair: redraw from the same category until the partner
            # differs from the anchor text.
            while True:
                rnd_idx = random.randint(0, len(class_samples) - 1)
                index_1 = class_samples[rnd_idx]
                if index_1 != index:
                    seq_1 = self.train_seq[index_1]
                    mask_1 = self.train_mask[index_1]
                    break
        else:
            # Negative pair: redraw from the whole dataset until the partner's
            # label differs from the anchor's (assumes more than one category).
            while True:
                index_1 = random.randint(0, len(self) - 1)
                label_1 = self.train_y[index_1].item()
                if label_1 != label_0:
                    seq_1 = self.train_seq[index_1]
                    mask_1 = self.train_mask[index_1]
                    break

        return seq_0, seq_1, mask_0, mask_1, torch.tensor(same_class, dtype=torch.float)

    def __len__(self):
        return len(self.train_seq)

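# Example usage (an illustrative sketch, not part of the original pipeline;
# `encodings` and `labels` are hypothetical names standing in for tokenizer
# output and integer category labels):
#
#     train_ds = ContrastiveDataset(encodings["input_ids"],
#                                   encodings["attention_mask"],
#                                   torch.tensor(labels),
#                                   positive_prob=0.5)
#     loader = DataLoader(train_ds, batch_size=32, shuffle=True)
#     seq_0, seq_1, mask_0, mask_1, same_class = next(iter(loader))
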
class SimpleDataset(Dataset):

    def __init__(self, seq, mask, y):
        super().__init__()
        self.seq = seq
        self.mask = mask
        self.y = y

    def __getitem__(self, index):
        """
        Return texts in the order of the dataset.
        :param index: index (int)
        :return: seq - a sequence of IDs (each ID represents a word in the vocabulary)
                 mask - attention mask for seq
                 y - the category of this text
        """
        return self.seq[index], self.mask[index], torch.tensor(self.y[index].item())

    def __len__(self):
        return len(self.seq)
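

if __name__ == "__main__":
    # Minimal smoke test on synthetic tensors (an illustrative sketch; real
    # inputs would be tokenizer IDs, attention masks, and category labels).
    n, seq_len, vocab_size, num_classes = 100, 16, 1000, 4
    seq = torch.randint(0, vocab_size, (n, seq_len))
    mask = torch.ones(n, seq_len, dtype=torch.long)
    y = torch.randint(0, num_classes, (n,))

    pairs = ContrastiveDataset(seq, mask, y, positive_prob=0.5)
    seq_0, seq_1, mask_0, mask_1, same_class = pairs[0]
    print("sampled pair, same_class =", same_class.item())

    loader = DataLoader(SimpleDataset(seq, mask, y), batch_size=8)
    batch_seq, batch_mask, batch_y = next(iter(loader))
    print("batch shapes:", batch_seq.shape, batch_mask.shape, batch_y.shape)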