--- a
+++ b/bert_mixup/early_mixup/dataloader.py
@@ -0,0 +1,128 @@
+from deepchem.molnet import load_bace_classification, load_bbbp
+import numpy as np
+
+from simcse import SimCSE
+import torch
+from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+
+from args_parser import parse_args
+import sys
+import pandas as pd
+
+
+# Maps ``args.dataset_name`` to the DeepChem MoleculeNet loader for that dataset.
+_datasets = {"bace": load_bace_classification, "bbbp": load_bbbp}
+
+
+def embed_smiles(model, smiles):
+    """Encode a list of SMILES strings with ``model.encode`` and return the result.
+
+    ``model`` is the SimCSE encoder built in ``get_dataloaders``.
+    """
+    embeddings = model.encode(smiles)
+    return embeddings
+
+
+def _first_task_labels(dataset):
+    """Return column 0 of ``dataset.y`` (the first task's label per sample)."""
+    # assumes dataset.y is 2-D (n_samples, n_tasks) -- only the first task is used
+    return np.asarray(dataset.y)[:, 0]
+
+
+def _sample_balanced_indices(labels, samples_per_class):
+    """Draw ``samples_per_class`` positive and negative indices without replacement.
+
+    Positive (label 1) indices come first, followed by negative (label 0) ones.
+    Draws use NumPy's global RNG.
+
+    Raises:
+        ValueError: if either class has fewer than ``samples_per_class`` examples.
+    """
+    picked = []
+    for cls in (1, 0):
+        pool = np.flatnonzero(labels == cls)
+        if len(pool) < samples_per_class:
+            raise ValueError(
+                f"samples_per_class={samples_per_class} exceeds the "
+                f"{len(pool)} training examples with label {cls}"
+            )
+        picked.extend(np.random.choice(pool, samples_per_class, replace=False))
+    return picked
+
+
+def _make_dataloader(model, smiles, labels, sampler_cls, batch_size):
+    """Embed ``smiles`` and wrap (embeddings, float32 labels) in a DataLoader."""
+    # as_tensor is a no-op for tensors and converts NumPy output if encode returns it
+    embeddings = torch.as_tensor(embed_smiles(model, smiles=list(smiles)))
+    data = TensorDataset(embeddings, torch.Tensor(labels))
+    return DataLoader(data, sampler=sampler_cls(data), batch_size=batch_size)
+
+
+def get_dataloaders(args):
+    """Build train/validation/test DataLoaders over SimCSE embeddings of SMILES.
+
+    Args:
+        args: namespace providing ``dataset_name`` (a key of ``_datasets``),
+            ``model_name_or_path`` (passed to ``SimCSE``), ``samples_per_class``
+            (> 0 draws that many positives and negatives for training; otherwise
+            the full training split is used) and ``batch_size`` (training only).
+
+    Returns:
+        ``(train_dataloader, val_dataloader, test_dataloader)``. Training is
+        shuffled; validation and test are served in order as one full batch.
+
+    Raises:
+        ValueError: for an unknown ``dataset_name`` or when a class has fewer
+            than ``samples_per_class`` training examples.
+    """
+    load_dataset = _datasets.get(args.dataset_name)
+    if load_dataset is None:
+        # Fail before loading the encoder rather than calling None later.
+        raise ValueError(
+            f"Unknown dataset_name {args.dataset_name!r}; "
+            f"expected one of {sorted(_datasets)}"
+        )
+
+    model = SimCSE(args.model_name_or_path)
+
+    _, datasets, _ = load_dataset(reload=False)
+    train_dataset, valid_dataset, test_dataset = datasets
+
+    # np.random.seed() with no argument reseeds NumPy's global RNG from OS
+    # entropy, presumably so the few-shot subset differs between runs.
+    # NOTE(review): this also discards any seed the caller set beforehand.
+    np.random.seed()
+
+    train_labels = _first_task_labels(train_dataset)
+    if args.samples_per_class > 0:
+        train_indices = _sample_balanced_indices(
+            train_labels, args.samples_per_class
+        )
+    else:
+        # No few-shot subsampling requested: keep the whole training split.
+        train_indices = np.arange(len(train_labels))
+
+    train_dataloader = _make_dataloader(
+        model,
+        train_dataset.ids[train_indices],
+        train_labels[train_indices],
+        RandomSampler,
+        args.batch_size,
+    )
+    # Evaluation splits: fixed order, entire split in a single batch.
+    val_dataloader = _make_dataloader(
+        model,
+        valid_dataset.ids,
+        _first_task_labels(valid_dataset),
+        SequentialSampler,
+        len(valid_dataset.ids),
+    )
+    test_dataloader = _make_dataloader(
+        model,
+        test_dataset.ids,
+        _first_task_labels(test_dataset),
+        SequentialSampler,
+        len(test_dataset.ids),
+    )
+
+    return train_dataloader, val_dataloader, test_dataloader