--- a
+++ b/bert_mixup/late_mixup/data.py
@@ -0,0 +1,260 @@
+import pandas as pd
+import torch
+from enumeration import SmilesEnumerator
+from keras_preprocessing.sequence import pad_sequences
+from transformers import BertTokenizer
+from tqdm import tqdm
+import logging
+from deepchem.molnet import load_bbbp, load_bace_classification
+import numpy as np
+
+## log INFO-level messages to a file
+logging.basicConfig(filename="data_loader.log", level=logging.INFO)
+
+## module-level logger object
+logger = logging.getLogger()
+
+
+MOLECULE_NET_DATASETS = {"bbbp": load_bbbp, "bace": load_bace_classification}
+
+
+class MoleculeData:
+    def __init__(
+        self,
+        dataset_name,
+        max_sequence_length=512,
+        debug=0,
+        n_augment=0,
+        samples_per_class=-1,
+        model_name_or_path="shahrukhx01/smole-bert",
+    ):
+        """
+        Load the dataset splits and the BERT tokenizer.
+        """
+        self.debug = debug
+        ## load data into memory
+        tasks, datasets, transformers = MOLECULE_NET_DATASETS[dataset_name](
+            reload=False
+        )
+        self.train_dataset, self.valid_dataset, self.test_dataset = datasets
+
+        ## set max sequence length for the model
+        self.max_sequence_length = max_sequence_length
+        ## get the BERT tokenizer
+        self.tokenizer = BertTokenizer.from_pretrained(
+            model_name_or_path, do_lower_case=True
+        )
+        self.enumerator = SmilesEnumerator()
+        self.n_augment = n_augment
+        self.samples_per_class = samples_per_class
+
+    def train_val_test_split(self):
+        """
+        Separate out labels and SMILES strings, optionally subsampling a
+        fixed number of molecules per class and augmenting the subsample
+        via SMILES enumeration.
+        """
+        ## effectively "no limit" unless debug mode caps the sample count
+        num_samples = 1_000_000
+        if self.debug:
+            print("Debug mode is enabled")
+            num_samples = 100
+        train_molecules = self.train_dataset.ids[:num_samples]
+        train_labels = np.array(
+            [int(label[0]) for label in self.train_dataset.y][:num_samples]
+        )
+
+        self.indices = []
+        self.augmented_data_index = len(train_molecules)
+        self.label_df = pd.DataFrame(train_labels, columns=["labels"])
+        if self.samples_per_class > 0:
+            ## note: seeding without an argument re-seeds from OS entropy,
+            ## so the per-class subsample differs between runs
+            np.random.seed()
+            tp = np.random.choice(
+                list(self.label_df[self.label_df["labels"] == 1].index),
+                self.samples_per_class,
+                replace=False,
+            )
+            tn = np.random.choice(
+                list(self.label_df[self.label_df["labels"] == 0].index),
+                self.samples_per_class,
+                replace=False,
+            )
+            self.indices = list(tp) + list(tn)
+        aug_molecules, aug_labels = [], []
+        ## augmentation only applies to the subsampled indices; with the
+        ## default samples_per_class=-1, self.indices stays empty and no
+        ## SMILES enumeration is performed
+        if self.n_augment:
+            for train_smiles, train_label in tqdm(
+                zip(train_molecules[self.indices], train_labels[self.indices]),
+                total=len(self.indices),
+            ):
+                molecules_augmented = self.enumerator.smiles_enumeration(
+                    input_smiles=train_smiles, n_augment=self.n_augment
+                )
+                if len(molecules_augmented):
+                    train_augmented_labels = [train_label] * len(molecules_augmented)
+                    aug_molecules += molecules_augmented
+                    aug_labels += train_augmented_labels
+        if len(aug_molecules) and len(aug_molecules) == len(aug_labels):
+            train_molecules, train_labels = list(train_molecules), list(train_labels)
+            train_molecules += aug_molecules
+            train_labels += aug_labels
+
+        val_molecules = self.valid_dataset.ids
+        val_labels = np.array([int(label[0]) for label in self.valid_dataset.y])
+
+        test_molecules = self.test_dataset.ids
+        test_labels = np.array([int(label[0]) for label in self.test_dataset.y])
+
+        return (
+            train_molecules,
+            val_molecules,
+            test_molecules,
+            train_labels,
+            val_labels,
+            test_labels,
+        )
+
+    def preprocess(self, texts):
+        """
+        Add the BERT special tokens ([CLS] and [SEP]) to each sequence
+        before tokenization.
+        """
+        ## prepend [CLS] and append [SEP] to every SMILES string
+        texts_processed = [
+            "[CLS] " + str(sequence) + " [SEP]" for sequence in texts
+        ]
+        return texts_processed
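+
+    # For example (illustrative values): preprocess(["CCO", "c1ccccc1"])
+    # returns ["[CLS] CCO [SEP]", "[CLS] c1ccccc1 [SEP]"]; tokenize() below
+    # then maps these markers to their reserved vocabulary ids.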
+
+    def tokenize(self, texts):
+        """
+        Use the BERT tokenizer to tokenize each sequence, then post-process
+        by padding or truncating to a fixed length.
+        """
+        ## tokenize each sequence
+        tokenized_molecules = [self.tokenizer.tokenize(text) for text in tqdm(texts)]
+
+        ## convert tokens to ids
+        print("convert tokens to ids")
+        text_ids = [
+            self.tokenizer.convert_tokens_to_ids(x) for x in tqdm(tokenized_molecules)
+        ]
+
+        ## pad or truncate every sequence to max_sequence_length
+        print("pad our text tokens for each sequence")
+        text_ids_post_processed = pad_sequences(
+            text_ids,
+            maxlen=self.max_sequence_length,
+            dtype="long",
+            truncating="post",
+            padding="post",
+        )
+        return text_ids_post_processed
+
+    def create_attention_mask(self, text_ids):
+        """
+        Build an attention mask that distinguishes real tokens from padding.
+        """
+        attention_masks = []
+        ## mask of 1s for each real token followed by 0s for padding
+        ## (assumes the pad token id is 0, as in standard BERT vocabularies)
+        for seq in tqdm(text_ids):
+            seq_mask = [float(i > 0) for i in seq]
+            attention_masks.append(seq_mask)
+        return attention_masks
+
+    def process_molecules(self):
+        """
+        Apply the preprocessing and tokenization pipeline to all splits.
+        """
+        ## perform the split
+        (
+            train_molecules,
+            val_molecules,
+            test_molecules,
+            train_labels,
+            val_labels,
+            test_labels,
+        ) = self.train_val_test_split()
+
+        print("preprocessing texts")
+        ## preprocess train, val, test texts
+        train_molecules_processed = self.preprocess(train_molecules)
+        val_molecules_processed = self.preprocess(val_molecules)
+        test_molecules_processed = self.preprocess(test_molecules)
+
+        del train_molecules
+        del val_molecules
+        del test_molecules
+
+        ## tokenize train, val, test texts
+        print("tokenizing train texts")
+        train_ids = self.tokenize(train_molecules_processed)
+        print("tokenizing val texts")
+        val_ids = self.tokenize(val_molecules_processed)
+        print("tokenizing test texts")
+        test_ids = self.tokenize(test_molecules_processed)
+
+        del train_molecules_processed
+        del val_molecules_processed
+        del test_molecules_processed
+
+        ## create attention masks for train, val, test texts
+        print("creating attention masks for train texts")
+        train_masks = self.create_attention_mask(train_ids)
+        print("creating attention masks for val texts")
+        val_masks = self.create_attention_mask(val_ids)
+        print("creating attention masks for test texts")
+        test_masks = self.create_attention_mask(test_ids)
+        return (
+            train_ids,
+            val_ids,
+            test_ids,
+            train_masks,
+            val_masks,
+            test_masks,
+            train_labels,
+            val_labels,
+            test_labels,
+        )
+
+    def text_to_tensors(self):
+        """
+        Convert all the data into torch tensors.
+        """
+        (
+            train_ids,
+            val_ids,
+            test_ids,
+            train_masks,
+            val_masks,
+            test_masks,
+            train_labels,
+            val_labels,
+            test_labels,
+        ) = self.process_molecules()
+
+        print("converting all variables to tensors")
+        ## convert inputs, masks and labels to torch tensors
+        self.train_inputs = torch.tensor(train_ids)
+        self.train_labels = torch.tensor(
+            train_labels,
+            dtype=torch.long,
+        )
+        self.train_masks = torch.tensor(train_masks)
+
+        self.validation_inputs = torch.tensor(val_ids)
+        self.validation_labels = torch.tensor(
+            val_labels,
+            dtype=torch.long,
+        )
+        self.validation_masks = torch.tensor(val_masks)
+
+        self.test_inputs = torch.tensor(test_ids)
+        self.test_labels = torch.tensor(
+            test_labels,
+            dtype=torch.long,
+        )
+        self.test_masks = torch.tensor(test_masks)
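+
+
+# Illustrative downstream usage (a sketch, not part of this module's API;
+# the hyperparameter values below are made up for the example):
+#
+#   from torch.utils.data import DataLoader, TensorDataset
+#
+#   data = MoleculeData(dataset_name="bbbp", n_augment=5, samples_per_class=32)
+#   data.text_to_tensors()
+#   train_set = TensorDataset(data.train_inputs, data.train_masks, data.train_labels)
+#   train_loader = DataLoader(train_set, batch_size=16, shuffle=True)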
+
+if __name__ == "__main__":
+    dataset_name = "bbbp"
+    data = MoleculeData(dataset_name=dataset_name)
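+    ## minimal smoke test (a sketch; assumes the deepchem bbbp download
+    ## succeeds): run the full pipeline and report the resulting tensor shapes
+    data.text_to_tensors()
+    print("train inputs:", tuple(data.train_inputs.shape))
+    print("validation inputs:", tuple(data.validation_inputs.shape))
+    print("test inputs:", tuple(data.test_inputs.shape))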