bert_mixup/late_mixup/data.py


import logging

import numpy as np
import pandas as pd
import torch
from deepchem.molnet import load_bace_classification, load_bbbp
from enumeration import SmilesEnumerator
from keras_preprocessing.sequence import pad_sequences
from tqdm import tqdm
from transformers import BertTokenizer

## set the logger threshold to INFO and log to a file
logging.basicConfig(filename="data_loader.log", level=logging.INFO)
## create a logger object
logger = logging.getLogger()

MOLECULE_NET_DATASETS = {"bbbp": load_bbbp, "bace": load_bace_classification}
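
## NOTE: `SmilesEnumerator.smiles_enumeration` comes from the local `enumeration`
## module (not shown here); based on how it is used below, it is assumed to return
## a list of up to `n_augment` alternative, chemically equivalent SMILES strings
## for `input_smiles` (e.g. by randomizing the atom ordering).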


class MoleculeData:
    def __init__(
        self,
        dataset_name,
        max_sequence_length=512,
        debug=0,
        n_augment=0,
        samples_per_class=-1,
        model_name_or_path="shahrukhx01/smole-bert",
    ):
        """
        Load the dataset and the BERT tokenizer.
        """
        self.debug = debug
        ## load data into memory
        tasks, datasets, transformers = MOLECULE_NET_DATASETS[dataset_name](
            reload=False
        )
        self.train_dataset, self.valid_dataset, self.test_dataset = datasets
        ## set max sequence length for the model
        self.max_sequence_length = max_sequence_length
        ## get the BERT tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            model_name_or_path, do_lower_case=True
        )
        self.enumerator = SmilesEnumerator()
        self.n_augment = n_augment
        self.samples_per_class = samples_per_class

    def train_val_test_split(self):
        """
        Separate out labels and texts, optionally sub-sampling and augmenting
        the training set.
        """
        num_samples = 1_000_000
        if self.debug:
            print("Debug mode is enabled")
            num_samples = 100
        train_molecules = self.train_dataset.ids[:num_samples]
        train_labels = np.array(
            [int(label[0]) for label in self.train_dataset.y][:num_samples]
        )
        self.indices = []
        tp, tn = [], []
        self.augmented_data_index = len(train_molecules)
        self.label_df = pd.DataFrame(train_labels, columns=["labels"])
        if self.samples_per_class > 0:
            ## sample an equal number of positive and negative training examples
            np.random.seed()
            tp = np.random.choice(
                list(self.label_df[self.label_df["labels"] == 1].index),
                self.samples_per_class,
                replace=False,
            )
            tn = np.random.choice(
                list(self.label_df[self.label_df["labels"] == 0].index),
                self.samples_per_class,
                replace=False,
            )
            self.indices = list(tp) + list(tn)
        aug_molecules, aug_labels = [], []
        if self.n_augment:
            ## enumerate alternative SMILES strings for each selected molecule
            for train_smiles, train_label in tqdm(
                zip(train_molecules[self.indices], train_labels[self.indices])
            ):
                molecules_augmented = self.enumerator.smiles_enumeration(
                    input_smiles=train_smiles, n_augment=self.n_augment
                )
                if len(molecules_augmented):
                    train_augmented_labels = [train_label] * len(molecules_augmented)
                    aug_molecules += molecules_augmented
                    aug_labels += train_augmented_labels
        if len(aug_molecules) and len(aug_molecules) == len(aug_labels):
            ## append the augmented molecules and their labels to the training set
            train_molecules, train_labels = list(train_molecules), list(train_labels)
            train_molecules += aug_molecules
            train_labels += aug_labels
        val_molecules = self.valid_dataset.ids
        val_labels = np.array([int(label[0]) for label in self.valid_dataset.y])
        test_molecules = self.test_dataset.ids
        test_labels = np.array([int(label[0]) for label in self.test_dataset.y])
        return (
            train_molecules,
            val_molecules,
            test_molecules,
            train_labels,
            val_labels,
            test_labels,
        )

    def preprocess(self, texts):
        """
        Add the BERT special tokens ([CLS] and [SEP]) to each sequence before tokenization.
        """
        ## prepend [CLS] and append [SEP] to every sequence, as BERT expects
        texts_processed = ["[CLS] " + str(sequence) + " [SEP]" for sequence in texts]
        return texts_processed

    def tokenize(self, texts):
        """
        Use the BERT tokenizer to tokenize each sequence and post-process
        by padding or truncating to a fixed length.
        """
        ## tokenize each sequence
        tokenized_molecules = [self.tokenizer.tokenize(text) for text in tqdm(texts)]
        ## convert tokens to ids
        print("convert tokens to ids")
        text_ids = [
            self.tokenizer.convert_tokens_to_ids(x) for x in tqdm(tokenized_molecules)
        ]
        ## pad our text tokens for each sequence
        print("pad our text tokens for each sequence")
        text_ids_post_processed = pad_sequences(
            text_ids,
            maxlen=self.max_sequence_length,
            dtype="long",
            truncating="post",
            padding="post",
        )
        return text_ids_post_processed

    def create_attention_mask(self, text_ids):
        """
        Add attention masks that zero out the padding tokens.
        """
        attention_masks = []
        ## create a mask of 1s for each real token followed by 0s for padding
        ## (assumes the padding token id is 0)
        for seq in tqdm(text_ids):
            seq_mask = [float(i > 0) for i in seq]
            attention_masks.append(seq_mask)
        return attention_masks

    def process_molecules(self):
        """
        Apply the preprocessing and tokenization pipeline to the texts.
        """
        ## perform the train/val/test split
        (
            train_molecules,
            val_molecules,
            test_molecules,
            train_labels,
            val_labels,
            test_labels,
        ) = self.train_val_test_split()
        print("preprocessing texts")
        ## preprocess train, val, test texts
        train_molecules_processed = self.preprocess(train_molecules)
        val_molecules_processed = self.preprocess(val_molecules)
        test_molecules_processed = self.preprocess(test_molecules)
        del train_molecules
        del val_molecules
        del test_molecules
        ## tokenize train, val, test texts
        print("tokenizing train texts")
        train_ids = self.tokenize(train_molecules_processed)
        print("tokenizing val texts")
        val_ids = self.tokenize(val_molecules_processed)
        print("tokenizing test texts")
        test_ids = self.tokenize(test_molecules_processed)
        del train_molecules_processed
        del val_molecules_processed
        del test_molecules_processed
        ## create attention masks for train, val, test texts
        print("creating train attention masks for texts")
        train_masks = self.create_attention_mask(train_ids)
        print("creating val attention masks for texts")
        val_masks = self.create_attention_mask(val_ids)
        print("creating test attention masks for texts")
        test_masks = self.create_attention_mask(test_ids)
        return (
            train_ids,
            val_ids,
            test_ids,
            train_masks,
            val_masks,
            test_masks,
            train_labels,
            val_labels,
            test_labels,
        )

    def text_to_tensors(self):
        """
        Convert all the data into torch tensors.
        """
        (
            train_ids,
            val_ids,
            test_ids,
            train_masks,
            val_masks,
            test_masks,
            train_labels,
            val_labels,
            test_labels,
        ) = self.process_molecules()
        print("converting all variables to tensors")
        ## convert inputs, masks and labels to torch tensors
        self.train_inputs = torch.tensor(train_ids)
        train_values = np.max(train_labels) + 1  ## number of label classes (currently unused)
        self.train_labels = torch.tensor(
            train_labels,
            dtype=torch.long,
        )
        self.train_masks = torch.tensor(train_masks)
        self.validation_inputs = torch.tensor(val_ids)
        self.validation_labels = torch.tensor(
            val_labels,
            dtype=torch.long,
        )
        self.validation_masks = torch.tensor(val_masks)
        self.test_inputs = torch.tensor(test_ids)
        self.test_labels = torch.tensor(
            test_labels,
            dtype=torch.long,
        )
        self.test_masks = torch.tensor(test_masks)


if __name__ == "__main__":
    dataset_name = "bbbp"
    MoleculeData(dataset_name=dataset_name)
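
    ## --- Illustrative usage sketch (an addition, not part of the original file) ---
    ## A minimal, hedged example of how the tensors built by `text_to_tensors`
    ## could be wrapped into a PyTorch DataLoader for a downstream BERT
    ## classifier; the names `data`, `train_data`, `train_dataloader` and the
    ## batch size are assumptions made purely for illustration.
    from torch.utils.data import DataLoader, TensorDataset

    data = MoleculeData(dataset_name=dataset_name, debug=1)  ## debug=1 keeps only 100 samples
    data.text_to_tensors()
    train_data = TensorDataset(data.train_inputs, data.train_masks, data.train_labels)
    train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
    for input_ids, attention_mask, labels in train_dataloader:
        ## each batch can now be fed to a BERT sequence classifier, e.g.
        ## model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        break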